1#![forbid(unsafe_code)]
7
8use gatekeep::{
9 Condition, Context, FactId, LowerError, Lowered, QueryLowering, ResidualPolicy,
10 ResidualPolicyBranch, ResidualPolicyNode,
11};
12
13mod fragment;
14
15pub use fragment::{PgFragment, PgValue};
16
17pub trait PgFactPredicates {
19 fn predicate(&self, fact: &FactId, cx: &Context) -> Option<PgFragment>;
22}
23
/// An outcome that can be encoded as an integer grade in SQL.
pub trait SqlOutcome {
    /// Returns this outcome's ordinal.
    ///
    /// [`OrdinalProjection`] binds this value as the grade of permitted rows.
    /// `Any` combinations keep the greatest grade and `All` combinations the
    /// least, so the relative order of ordinals is significant.
    fn to_sql_ordinal(&self) -> i64;
}
29
30impl SqlOutcome for () {
31 fn to_sql_ordinal(&self) -> i64 {
32 0
33 }
34}
35
36pub trait OutcomeProjection<O> {
38 fn constant(&self, outcome: &O) -> Result<PgFragment, LowerError>;
40}
41
/// Projects each outcome to its [`SqlOutcome::to_sql_ordinal`] value, passed
/// to SQL through `PgFragment::bind`.
#[derive(Debug, Default, Clone, Copy)]
pub struct OrdinalProjection;
45
46impl<O: SqlOutcome> OutcomeProjection<O> for OrdinalProjection {
47 fn constant(&self, outcome: &O) -> Result<PgFragment, LowerError> {
48 Ok(PgFragment::bind(outcome.to_sql_ordinal()))
49 }
50}
51
/// Projection for callers that only need a row filter. Every grade request
/// fails with [`LowerError::NonTotalGrade`].
///
/// Pair it with `PgLowerer::lower_filter`, which never consults the
/// projection.
#[derive(Debug, Default, Clone, Copy)]
pub struct NoGradeProjection;
55
56impl<O> OutcomeProjection<O> for NoGradeProjection {
57 fn constant(&self, _outcome: &O) -> Result<PgFragment, LowerError> {
58 Err(LowerError::NonTotalGrade)
59 }
60}
61
62#[derive(Clone, Debug)]
64pub struct PgLowerer<P, M = OrdinalProjection> {
65 predicates: P,
66 projection: M,
67}
68
69#[derive(Clone, Debug, PartialEq, Eq)]
70struct PgLowered {
71 filter: PgFragment,
72 grade: PgFragment,
73}
74
75impl<P> PgLowerer<P, OrdinalProjection> {
76 #[must_use]
78 pub const fn new(predicates: P) -> Self {
79 Self::with_projection(predicates, OrdinalProjection)
80 }
81}
82
83impl<P, M> PgLowerer<P, M> {
84 #[must_use]
86 pub const fn with_projection(predicates: P, projection: M) -> Self {
87 Self {
88 predicates,
89 projection,
90 }
91 }
92
93 pub fn lower_filter<O>(
95 &self,
96 residual: &ResidualPolicy<O>,
97 cx: &Context,
98 ) -> Result<PgFragment, LowerError>
99 where
100 P: PgFactPredicates,
101 {
102 residual.try_fold_pruned(
103 &mut |branch| match branch {
104 ResidualPolicyBranch::OrElseFallback { fallback, .. } => {
105 !fallback.carries_obligation()
106 }
107 },
108 &mut |node| self.lower_filter_node(node, cx),
109 )
110 }
111
112 fn lower_filter_node<O>(
113 &self,
114 node: ResidualPolicyNode<'_, O, PgFragment>,
115 cx: &Context,
116 ) -> Result<PgFragment, LowerError>
117 where
118 P: PgFactPredicates,
119 {
120 match node {
121 ResidualPolicyNode::Permit(_) | ResidualPolicyNode::PermitWithTrace { .. } => {
122 Ok(PgFragment::trusted("TRUE"))
123 }
124 ResidualPolicyNode::Deny | ResidualPolicyNode::DenyWithTrace { .. } => {
125 Ok(PgFragment::trusted("FALSE"))
126 }
127 ResidualPolicyNode::Grant { condition, .. } => self.lower_condition(condition, cx),
128 ResidualPolicyNode::All { arms, .. } => Ok(fragment_set(arms, " AND ", "FALSE")),
129 ResidualPolicyNode::Any { arms, .. } => Ok(fragment_set(arms, " OR ", "FALSE")),
130 ResidualPolicyNode::OrElse {
131 fallback_policy,
132 primary,
133 fallback,
134 ..
135 } => {
136 if fallback_policy.carries_obligation() {
137 Ok(primary)
138 } else {
139 Ok(match fallback {
140 Some(fallback) => PgFragment::binary(" OR ", vec![primary, fallback]),
141 None => primary,
142 })
143 }
144 }
145 }
146 }
147
148 fn lower_condition(&self, condition: &Condition, cx: &Context) -> Result<PgFragment, LowerError>
149 where
150 P: PgFactPredicates,
151 {
152 match condition {
153 Condition::Always => Ok(PgFragment::trusted("TRUE")),
154 Condition::Never => Ok(PgFragment::trusted("FALSE")),
155 Condition::Has(fact) => self
156 .predicates
157 .predicate(fact, cx)
158 .map(is_true)
159 .ok_or_else(|| LowerError::Unlowerable(fact.clone())),
160 Condition::Not(inner) => {
161 Ok(PgFragment::unary("NOT ", self.lower_condition(inner, cx)?))
162 }
163 Condition::All(conditions) => {
164 lower_condition_set(conditions, " AND ", "FALSE", |item| {
165 self.lower_condition(item, cx)
166 })
167 }
168 Condition::Any(conditions) => {
169 lower_condition_set(conditions, " OR ", "FALSE", |item| {
170 self.lower_condition(item, cx)
171 })
172 }
173 }
174 }
175
176 fn lower_policy<O>(
177 &self,
178 residual: &ResidualPolicy<O>,
179 cx: &Context,
180 ) -> Result<PgLowered, LowerError>
181 where
182 P: PgFactPredicates,
183 M: OutcomeProjection<O>,
184 {
185 residual.try_fold_pruned(
186 &mut |branch| match branch {
187 ResidualPolicyBranch::OrElseFallback { fallback, .. } => {
188 !fallback.carries_obligation()
189 }
190 },
191 &mut |node| self.lower_node(node, cx),
192 )
193 }
194
195 fn lower_node<O>(
196 &self,
197 node: ResidualPolicyNode<'_, O, PgLowered>,
198 cx: &Context,
199 ) -> Result<PgLowered, LowerError>
200 where
201 P: PgFactPredicates,
202 M: OutcomeProjection<O>,
203 {
204 match node {
205 ResidualPolicyNode::Permit(outcome)
206 | ResidualPolicyNode::PermitWithTrace { outcome, .. } => Ok(PgLowered {
207 filter: PgFragment::trusted("TRUE"),
208 grade: self.projection.constant(outcome)?,
209 }),
210 ResidualPolicyNode::Deny | ResidualPolicyNode::DenyWithTrace { .. } => Ok(PgLowered {
211 filter: PgFragment::trusted("FALSE"),
212 grade: PgFragment::trusted("NULL"),
213 }),
214 ResidualPolicyNode::Grant {
215 outcome, condition, ..
216 } => {
217 let filter = self.lower_condition(condition, cx)?;
218 let outcome = self.projection.constant(outcome)?;
219 Ok(PgLowered {
220 filter: filter.clone(),
221 grade: case_when(filter, outcome, PgFragment::trusted("NULL")),
222 })
223 }
224 ResidualPolicyNode::All { arms, .. } => {
225 let (filters, grades) = unzip_lowered(arms);
226 Ok(PgLowered {
227 filter: fragment_set(filters, " AND ", "FALSE"),
228 grade: grade_set(grades, "LEAST"),
229 })
230 }
231 ResidualPolicyNode::Any { arms, .. } => {
232 let (filters, grades) = unzip_lowered(arms);
233 Ok(PgLowered {
234 filter: fragment_set(filters, " OR ", "FALSE"),
235 grade: grade_set(grades, "GREATEST"),
236 })
237 }
238 ResidualPolicyNode::OrElse {
239 fallback_policy,
240 primary,
241 fallback,
242 ..
243 } => {
244 if fallback_policy.carries_obligation() {
245 return Ok(primary);
246 }
247
248 Ok(match fallback {
249 Some(fallback) => PgLowered {
250 filter: PgFragment::binary(
251 " OR ",
252 vec![primary.filter.clone(), fallback.filter],
253 ),
254 grade: case_when(primary.filter, primary.grade, fallback.grade),
255 },
256 None => primary,
257 })
258 }
259 }
260 }
261}
262
263impl<O, P, M> QueryLowering<O> for PgLowerer<P, M>
264where
265 P: PgFactPredicates,
266 M: OutcomeProjection<O>,
267{
268 type Filter = PgFragment;
269 type Projection = PgFragment;
270
271 fn lower(
272 &self,
273 residual: &ResidualPolicy<O>,
274 cx: &Context,
275 ) -> Result<Lowered<Self::Filter, Self::Projection>, LowerError> {
276 let lowered = self.lower_policy(residual, cx)?;
277 Ok(Lowered {
278 filter: lowered.filter,
279 grade: lowered.grade,
280 })
281 }
282}
283
284fn lower_condition_set(
285 conditions: &[Condition],
286 separator: &str,
287 empty: &str,
288 lower: impl FnMut(&Condition) -> Result<PgFragment, LowerError>,
289) -> Result<PgFragment, LowerError> {
290 if conditions.is_empty() {
291 return Ok(PgFragment::trusted(empty));
292 }
293 let fragments = conditions
294 .iter()
295 .map(lower)
296 .collect::<Result<Vec<_>, _>>()?;
297 Ok(PgFragment::binary(separator, fragments))
298}
299
300fn fragment_set(fragments: Vec<PgFragment>, separator: &str, empty: &str) -> PgFragment {
301 if fragments.is_empty() {
302 PgFragment::trusted(empty)
303 } else {
304 PgFragment::binary(separator, fragments)
305 }
306}
307
308fn grade_set(grades: Vec<PgFragment>, function: &str) -> PgFragment {
309 if grades.is_empty() {
310 PgFragment::trusted("NULL")
311 } else {
312 PgFragment::function(function, grades)
313 }
314}
315
316fn unzip_lowered(lowered: Vec<PgLowered>) -> (Vec<PgFragment>, Vec<PgFragment>) {
317 lowered
318 .into_iter()
319 .map(|lowered| (lowered.filter, lowered.grade))
320 .unzip()
321}
322
323fn case_when(condition: PgFragment, then_expr: PgFragment, else_expr: PgFragment) -> PgFragment {
324 let mut fragment = PgFragment::trusted("CASE WHEN ");
325 fragment.push_fragment(condition);
326 fragment.push_sql(" THEN ");
327 fragment.push_fragment(then_expr);
328 fragment.push_sql(" ELSE ");
329 fragment.push_fragment(else_expr);
330 fragment.push_sql(" END");
331 fragment
332}
333
334fn is_true(predicate: PgFragment) -> PgFragment {
335 let mut fragment = PgFragment::trusted("(");
336 fragment.push_fragment(predicate);
337 fragment.push_sql(") IS TRUE");
338 fragment
339}