#![forbid(unsafe_code)]
use gatekeep::{
Condition, Context, FactId, LowerError, Lowered, QueryLowering, ResidualPolicy,
ResidualPolicyBranch, ResidualPolicyNode,
};
mod fragment;
pub use fragment::{PgFragment, PgValue};
/// Supplies the SQL predicate that stands for a policy fact.
///
/// The lowerer calls this for every `Condition::Has(fact)` it meets and wraps
/// the returned fragment as `(<fragment>) IS TRUE`.
pub trait PgFactPredicates {
/// Returns the fragment that tests `fact` under `cx`, or `None` when the
/// fact has no SQL form; `None` makes lowering fail with
/// `LowerError::Unlowerable` carrying a clone of that fact.
fn predicate(&self, fact: &FactId, cx: &Context) -> Option<PgFragment>;
}
/// An outcome that can be represented in SQL by a 64-bit integer ordinal.
pub trait SqlOutcome {
    /// Returns the ordinal that stands for this outcome.
    fn to_sql_ordinal(&self) -> i64;
}

/// The unit outcome carries no information, so it always maps to ordinal `0`.
impl SqlOutcome for () {
    fn to_sql_ordinal(&self) -> i64 {
        0
    }
}
/// Strategy for turning a policy outcome of type `O` into the SQL fragment
/// that the lowerer uses as the grade of a permitting node.
pub trait OutcomeProjection<O> {
/// Builds the grade fragment for `outcome`, or returns a `LowerError` when
/// this projection cannot express it.
fn constant(&self, outcome: &O) -> Result<PgFragment, LowerError>;
}
/// Projection that grades an outcome by its [`SqlOutcome`] ordinal.
#[derive(Clone, Copy, Debug, Default)]
pub struct OrdinalProjection;

impl<O> OutcomeProjection<O> for OrdinalProjection
where
    O: SqlOutcome,
{
    /// Passes the outcome's ordinal to [`PgFragment::bind`]; never returns `Err`.
    fn constant(&self, outcome: &O) -> Result<PgFragment, LowerError> {
        let ordinal = outcome.to_sql_ordinal();
        Ok(PgFragment::bind(ordinal))
    }
}
/// Projection that refuses to grade: every outcome is rejected.
#[derive(Clone, Copy, Debug, Default)]
pub struct NoGradeProjection;

impl<O> OutcomeProjection<O> for NoGradeProjection {
    /// Always fails with [`LowerError::NonTotalGrade`], whatever the outcome.
    fn constant(&self, _outcome: &O) -> Result<PgFragment, LowerError> {
        Err(LowerError::NonTotalGrade)
    }
}
/// Lowers residual policies into PostgreSQL fragments.
///
/// `P` provides the per-fact predicates and `M` builds grade constants from
/// outcomes; `M` defaults to `OrdinalProjection`.
#[derive(Clone, Debug)]
pub struct PgLowerer<P, M = OrdinalProjection> {
// Consulted for every `Condition::Has` fact during lowering.
predicates: P,
// Consulted for the outcome of every permitting node when a grade is built.
projection: M,
}
/// Intermediate result of lowering one policy node: a filter fragment paired
/// with the grade fragment built for the same node.
#[derive(Clone, Debug, PartialEq, Eq)]
struct PgLowered {
// Boolean SQL deciding whether the node permits.
filter: PgFragment,
// SQL producing the projected outcome; `NULL` for denying nodes.
grade: PgFragment,
}
impl<P> PgLowerer<P, OrdinalProjection> {
    /// Creates a lowerer over `predicates` that grades outcomes with
    /// [`OrdinalProjection`].
    #[must_use]
    pub const fn new(predicates: P) -> Self {
        Self {
            predicates,
            projection: OrdinalProjection,
        }
    }
}
impl<P, M> PgLowerer<P, M> {
    /// Creates a lowerer from a predicate source and an explicit outcome
    /// projection.
    #[must_use]
    pub const fn with_projection(predicates: P, projection: M) -> Self {
        Self {
            predicates,
            projection,
        }
    }

    /// Lowers `residual` to a boolean filter fragment only.
    ///
    /// No grade is built, so `M` needs no [`OutcomeProjection`] bound here.
    ///
    /// # Errors
    ///
    /// Propagates any [`LowerError`] raised while lowering a node, such as
    /// [`LowerError::Unlowerable`] when `predicates` has no SQL for a fact.
    pub fn lower_filter<O>(
        &self,
        residual: &ResidualPolicy<O>,
        cx: &Context,
    ) -> Result<PgFragment, LowerError>
    where
        P: PgFactPredicates,
    {
        residual.try_fold_pruned(
            // Answers `true` exactly when the fallback carries no obligation.
            // NOTE(review): presumably `try_fold_pruned` skips branches that
            // are answered with `false` — confirm against gatekeep.
            &mut |branch| match branch {
                ResidualPolicyBranch::OrElseFallback { fallback, .. } => {
                    !fallback.carries_obligation()
                }
            },
            &mut |node| self.lower_filter_node(node, cx),
        )
    }

    /// Lowers a single node whose children are already filter fragments.
    ///
    /// Permits become `TRUE`, denies become `FALSE`, grants become their
    /// lowered condition, and `All`/`Any` join their arms with `AND`/`OR`
    /// (an empty arm list yields `FALSE` in both cases).
    fn lower_filter_node<O>(
        &self,
        node: ResidualPolicyNode<'_, O, PgFragment>,
        cx: &Context,
    ) -> Result<PgFragment, LowerError>
    where
        P: PgFactPredicates,
    {
        let lowered = match node {
            ResidualPolicyNode::Grant { condition, .. } => self.lower_condition(condition, cx)?,
            ResidualPolicyNode::Permit(_) | ResidualPolicyNode::PermitWithTrace { .. } => {
                PgFragment::trusted("TRUE")
            }
            ResidualPolicyNode::Deny | ResidualPolicyNode::DenyWithTrace { .. } => {
                PgFragment::trusted("FALSE")
            }
            ResidualPolicyNode::All { arms, .. } => fragment_set(arms, " AND ", "FALSE"),
            ResidualPolicyNode::Any { arms, .. } => fragment_set(arms, " OR ", "FALSE"),
            ResidualPolicyNode::OrElse {
                fallback_policy,
                primary,
                fallback,
                ..
            } => {
                // The fallback only widens the filter when it carries no
                // obligation and a lowered fallback is actually present.
                let obligated = fallback_policy.carries_obligation();
                match (obligated, fallback) {
                    (false, Some(alternative)) => {
                        PgFragment::binary(" OR ", vec![primary, alternative])
                    }
                    (true, _) | (false, None) => primary,
                }
            }
        };
        Ok(lowered)
    }

    /// Lowers a condition tree to a boolean fragment.
    ///
    /// `Has` facts are resolved through `predicates` and wrapped with
    /// [`is_true`]; a fact without a predicate fails with
    /// [`LowerError::Unlowerable`]. Empty `All` and empty `Any` both lower to
    /// `FALSE`.
    fn lower_condition(&self, condition: &Condition, cx: &Context) -> Result<PgFragment, LowerError>
    where
        P: PgFactPredicates,
    {
        match condition {
            Condition::Always => Ok(PgFragment::trusted("TRUE")),
            Condition::Never => Ok(PgFragment::trusted("FALSE")),
            Condition::Has(fact) => match self.predicates.predicate(fact, cx) {
                Some(predicate) => Ok(is_true(predicate)),
                None => Err(LowerError::Unlowerable(fact.clone())),
            },
            Condition::Not(inner) => self
                .lower_condition(inner, cx)
                .map(|negated| PgFragment::unary("NOT ", negated)),
            Condition::All(members) => {
                lower_condition_set(members, " AND ", "FALSE", |member| {
                    self.lower_condition(member, cx)
                })
            }
            Condition::Any(members) => {
                lower_condition_set(members, " OR ", "FALSE", |member| {
                    self.lower_condition(member, cx)
                })
            }
        }
    }

    /// Lowers `residual` to a filter together with its grade expression.
    ///
    /// Uses the same branch callback as [`Self::lower_filter`].
    fn lower_policy<O>(
        &self,
        residual: &ResidualPolicy<O>,
        cx: &Context,
    ) -> Result<PgLowered, LowerError>
    where
        P: PgFactPredicates,
        M: OutcomeProjection<O>,
    {
        residual.try_fold_pruned(
            // Answers `true` exactly when the fallback carries no obligation.
            &mut |branch| match branch {
                ResidualPolicyBranch::OrElseFallback { fallback, .. } => {
                    !fallback.carries_obligation()
                }
            },
            &mut |node| self.lower_node(node, cx),
        )
    }

    /// Lowers a single node whose children are already [`PgLowered`] pairs.
    ///
    /// Filters follow the same rules as [`Self::lower_filter_node`]. Grades
    /// are the projected outcome for permits, `NULL` for denies, a
    /// `CASE WHEN` guarded outcome for grants, `LEAST`/`GREATEST` over the
    /// arms for `All`/`Any`, and a `CASE WHEN` choosing between primary and
    /// fallback for `OrElse`.
    fn lower_node<O>(
        &self,
        node: ResidualPolicyNode<'_, O, PgLowered>,
        cx: &Context,
    ) -> Result<PgLowered, LowerError>
    where
        P: PgFactPredicates,
        M: OutcomeProjection<O>,
    {
        /// Joins arm filters with `separator` and folds arm grades with the
        /// SQL function `function`.
        // NOTE(review): PostgreSQL `LEAST`/`GREATEST` ignore NULL arguments,
        // so the combined grade looks meaningful only where the combined
        // filter holds — confirm callers always apply the filter.
        fn combine(arms: Vec<PgLowered>, separator: &str, function: &str) -> PgLowered {
            let (filters, grades) = unzip_lowered(arms);
            PgLowered {
                filter: fragment_set(filters, separator, "FALSE"),
                grade: grade_set(grades, function),
            }
        }

        match node {
            ResidualPolicyNode::Permit(outcome)
            | ResidualPolicyNode::PermitWithTrace { outcome, .. } => {
                let filter = PgFragment::trusted("TRUE");
                let grade = self.projection.constant(outcome)?;
                Ok(PgLowered { filter, grade })
            }
            ResidualPolicyNode::Deny | ResidualPolicyNode::DenyWithTrace { .. } => {
                let filter = PgFragment::trusted("FALSE");
                let grade = PgFragment::trusted("NULL");
                Ok(PgLowered { filter, grade })
            }
            ResidualPolicyNode::Grant {
                outcome, condition, ..
            } => {
                // The condition is lowered before the outcome is projected, so
                // a condition error takes precedence over a projection error.
                let guard = self.lower_condition(condition, cx)?;
                let granted = self.projection.constant(outcome)?;
                let grade = case_when(guard.clone(), granted, PgFragment::trusted("NULL"));
                Ok(PgLowered {
                    filter: guard,
                    grade,
                })
            }
            ResidualPolicyNode::All { arms, .. } => Ok(combine(arms, " AND ", "LEAST")),
            ResidualPolicyNode::Any { arms, .. } => Ok(combine(arms, " OR ", "GREATEST")),
            ResidualPolicyNode::OrElse {
                fallback_policy,
                primary,
                fallback,
                ..
            } => {
                let obligated = fallback_policy.carries_obligation();
                match (obligated, fallback) {
                    (false, Some(alternative)) => {
                        let PgLowered {
                            filter: primary_filter,
                            grade: primary_grade,
                        } = primary;
                        // The primary grade applies where the primary filter
                        // holds; elsewhere the fallback grade is used.
                        Ok(PgLowered {
                            filter: PgFragment::binary(
                                " OR ",
                                vec![primary_filter.clone(), alternative.filter],
                            ),
                            grade: case_when(primary_filter, primary_grade, alternative.grade),
                        })
                    }
                    (true, _) | (false, None) => Ok(primary),
                }
            }
        }
    }
}
impl<O, P, M> QueryLowering<O> for PgLowerer<P, M>
where
P: PgFactPredicates,
M: OutcomeProjection<O>,
{
type Filter = PgFragment;
type Projection = PgFragment;
fn lower(
&self,
residual: &ResidualPolicy<O>,
cx: &Context,
) -> Result<Lowered<Self::Filter, Self::Projection>, LowerError> {
let lowered = self.lower_policy(residual, cx)?;
Ok(Lowered {
filter: lowered.filter,
grade: lowered.grade,
})
}
}
/// Lowers every condition with `lower` and joins the results with `separator`.
///
/// An empty slice yields the literal fragment `empty` without calling `lower`.
/// Lowering stops at the first error, which is returned as is.
fn lower_condition_set(
    conditions: &[Condition],
    separator: &str,
    empty: &str,
    mut lower: impl FnMut(&Condition) -> Result<PgFragment, LowerError>,
) -> Result<PgFragment, LowerError> {
    if conditions.is_empty() {
        return Ok(PgFragment::trusted(empty));
    }

    let mut lowered = Vec::with_capacity(conditions.len());
    for condition in conditions {
        lowered.push(lower(condition)?);
    }
    Ok(PgFragment::binary(separator, lowered))
}
/// Joins already-lowered fragments with `separator`, or returns the literal
/// fragment `empty` when there is nothing to join.
fn fragment_set(fragments: Vec<PgFragment>, separator: &str, empty: &str) -> PgFragment {
    if fragments.is_empty() {
        return PgFragment::trusted(empty);
    }
    PgFragment::binary(separator, fragments)
}
/// Folds grade fragments with the SQL function named `function`, or returns
/// `NULL` when there are no grades at all.
fn grade_set(grades: Vec<PgFragment>, function: &str) -> PgFragment {
    if grades.is_empty() {
        return PgFragment::trusted("NULL");
    }
    PgFragment::function(function, grades)
}
fn unzip_lowered(lowered: Vec<PgLowered>) -> (Vec<PgFragment>, Vec<PgFragment>) {
lowered
.into_iter()
.map(|lowered| (lowered.filter, lowered.grade))
.unzip()
}
/// Builds `CASE WHEN <condition> THEN <then_expr> ELSE <else_expr> END`.
fn case_when(condition: PgFragment, then_expr: PgFragment, else_expr: PgFragment) -> PgFragment {
    let mut sql = PgFragment::trusted("CASE WHEN ");
    sql.push_fragment(condition);
    // Each keyword is emitted immediately before the operand it introduces.
    for (keyword, operand) in [(" THEN ", then_expr), (" ELSE ", else_expr)] {
        sql.push_sql(keyword);
        sql.push_fragment(operand);
    }
    sql.push_sql(" END");
    sql
}
/// Wraps `predicate` as `(<predicate>) IS TRUE`.
///
/// In PostgreSQL `IS TRUE` yields false, not NULL, for a NULL operand, so the
/// wrapped predicate is always a definite boolean.
fn is_true(predicate: PgFragment) -> PgFragment {
    let mut guarded = PgFragment::trusted("(");
    guarded.push_fragment(predicate);
    guarded.push_sql(") IS TRUE");
    guarded
}