use super::auth_context::AuthContext;
use super::predicate::{CompareOp, PolicyMode, PredicateValue, RlsPredicate};
use crate::bridge::scan_filter::ScanFilter;
/// Translates an RLS predicate into scan filters, substituting auth-derived values.
///
/// The returned filters are implicitly AND-ed together, so an empty vector places no
/// restriction on the scan. Two sentinel shapes are produced:
/// * a filter with `op == "match_all"` and an empty field for a statically true predicate;
/// * `__rls_deny__ is_not_null` for a statically false predicate
///   (NOTE(review): this relies on stored documents never carrying a `__rls_deny__`
///   field — confirm).
///
/// Returns `None` when part of the predicate cannot be expressed as scan filters
/// (e.g. a field-to-field comparison) or an auth value cannot be resolved. Inside an
/// `Or`, children that return `None` are skipped; removing a disjunct can only shrink
/// the set of visible rows.
pub fn substitute_to_scan_filters(
    predicate: &RlsPredicate,
    auth: &AuthContext,
) -> Option<Vec<ScanFilter>> {
    // Sentinel filter for a predicate that is statically true.
    fn match_all() -> ScanFilter {
        ScanFilter {
            field: String::new(),
            op: "match_all".into(),
            value: serde_json::Value::Null,
            clauses: Vec::new(),
        }
    }

    // Sentinel filter for a predicate that is statically false.
    fn deny_all() -> ScanFilter {
        ScanFilter {
            field: "__rls_deny__".into(),
            op: "is_not_null".into(),
            value: serde_json::Value::Null,
            clauses: Vec::new(),
        }
    }

    match predicate {
        RlsPredicate::AlwaysTrue => Some(vec![match_all()]),
        RlsPredicate::AlwaysFalse => Some(vec![deny_all()]),
        RlsPredicate::Compare { field, op, value } => {
            let resolved = match value {
                PredicateValue::Literal(v) => v.clone(),
                PredicateValue::AuthRef(auth_field) => auth.resolve_variable(auth_field)?,
                PredicateValue::AuthFunc { .. } => value.resolve(auth)?,
                // Field-to-field comparisons have no scan-filter form.
                PredicateValue::Field(_) => return None,
            };
            Some(vec![ScanFilter {
                field: field.clone(),
                op: op.as_filter_op().into(),
                value: resolved,
                clauses: Vec::new(),
            }])
        }
        RlsPredicate::Contains { set, element } => substitute_contains(set, element, auth),
        RlsPredicate::Intersects { left, right } => substitute_intersects(left, right, auth),
        RlsPredicate::And(children) => {
            // Every conjunct must be expressible; one failure makes the whole AND `None`.
            let mut combined = Vec::new();
            for child in children {
                combined.extend(substitute_to_scan_filters(child, auth)?);
            }
            Some(combined)
        }
        RlsPredicate::Or(children) => {
            let mut clause_groups: Vec<Vec<ScanFilter>> = Vec::new();
            for child in children {
                // Unconvertible disjuncts are skipped (can only hide rows, never reveal them).
                if let Some(filters) = substitute_to_scan_filters(child, auth) {
                    // An empty conjunction, or a lone match_all, is vacuously true, which
                    // makes the whole disjunction true. Emitting an empty clause group would
                    // leave its meaning up to the executor.
                    let always_true = filters.is_empty()
                        || (filters.len() == 1 && filters[0].op == "match_all");
                    if always_true {
                        return Some(vec![match_all()]);
                    }
                    clause_groups.push(filters);
                }
            }
            match clause_groups.len() {
                // Empty disjunction is false.
                0 => Some(vec![deny_all()]),
                // A single disjunct needs no `or` wrapper.
                1 => clause_groups.pop(),
                _ => Some(vec![ScanFilter {
                    field: String::new(),
                    op: "or".into(),
                    value: serde_json::Value::Null,
                    clauses: clause_groups,
                }]),
            }
        }
        RlsPredicate::Not(inner) => substitute_not(inner, auth),
    }
}
/// Combines a collection's RLS policies into one list of AND-ed scan filters.
///
/// Follows PostgreSQL's combination rules:
/// * permissive policies are OR-ed together;
/// * restrictive policies are each AND-ed onto that result;
/// * a row is therefore visible only if at least one permissive policy admits it and
///   every restrictive policy does.
///
/// When only restrictive policies exist, nothing grants access, so the result is the
/// deny sentinel.
///
/// An empty policy list yields an empty filter list, i.e. no restriction.
///
/// Returns `None` if any predicate that must be applied cannot be converted into scan
/// filters.
pub fn combine_policies(
    policies: &[(RlsPredicate, PolicyMode)],
    auth: &AuthContext,
) -> Option<Vec<ScanFilter>> {
    if policies.is_empty() {
        return Some(Vec::new());
    }

    let mut permissive: Vec<&RlsPredicate> = Vec::new();
    let mut restrictive: Vec<&RlsPredicate> = Vec::new();
    for (pred, mode) in policies {
        match mode {
            PolicyMode::Permissive => permissive.push(pred),
            PolicyMode::Restrictive => restrictive.push(pred),
        }
    }

    let mut combined = match permissive.as_slice() {
        // Restrictive policies only narrow access that some permissive policy grants.
        // With no permissive policy, nothing is granted (empty OR is false).
        [] => return substitute_to_scan_filters(&RlsPredicate::AlwaysFalse, auth),
        [only] => substitute_to_scan_filters(only, auth)?,
        many => {
            let or_pred = RlsPredicate::Or(many.iter().map(|p| (*p).clone()).collect());
            substitute_to_scan_filters(&or_pred, auth)?
        }
    };

    for pred in &restrictive {
        combined.extend(substitute_to_scan_filters(pred, auth)?);
    }
    Some(combined)
}
/// Translates a `set CONTAINS element` predicate into scan filters.
///
/// Handled combinations:
/// * **Auth-derived set, literal element.** The set is an `AuthRef` or `AuthFunc`. It is
///   decided immediately:
///   - the match-all sentinel when the resolved JSON array holds the literal;
///   - the deny sentinel when it does not;
///   - `None` when the resolved value is not an array.
/// * **Document-field set.** This becomes a `contains` filter on that field. The element
///   is a literal or a resolved auth value.
///
/// Every other combination returns `None`, as does any auth value that cannot be
/// resolved.
fn substitute_contains(
    set: &PredicateValue,
    element: &PredicateValue,
    auth: &AuthContext,
) -> Option<Vec<ScanFilter>> {
    // Auth variables are looked up with `resolve_variable`; auth functions with `resolve`.
    fn resolve_auth(operand: &PredicateValue, auth: &AuthContext) -> Option<serde_json::Value> {
        match operand {
            PredicateValue::AuthRef(name) => auth.resolve_variable(name),
            _ => operand.resolve(auth),
        }
    }

    // `contains` filter on a document field.
    fn contains_filter(field: String, value: serde_json::Value) -> ScanFilter {
        ScanFilter {
            field,
            op: "contains".into(),
            value,
            clauses: Vec::new(),
        }
    }

    // Statically decided membership: match-all when present, deny otherwise.
    fn membership_filter(present: bool) -> ScanFilter {
        if present {
            ScanFilter {
                field: String::new(),
                op: "match_all".into(),
                value: serde_json::Value::Null,
                clauses: Vec::new(),
            }
        } else {
            ScanFilter {
                field: "__rls_deny__".into(),
                op: "is_not_null".into(),
                value: serde_json::Value::Null,
                clauses: Vec::new(),
            }
        }
    }

    match (set, element) {
        (
            PredicateValue::AuthRef(_) | PredicateValue::AuthFunc { .. },
            PredicateValue::Literal(lit),
        ) => {
            let resolved_set = resolve_auth(set, auth)?;
            // A non-array auth value cannot be searched.
            let members = resolved_set.as_array()?;
            Some(vec![membership_filter(members.contains(lit))])
        }
        (
            PredicateValue::Field(doc_field),
            PredicateValue::AuthRef(_) | PredicateValue::AuthFunc { .. },
        ) => {
            let needle = resolve_auth(element, auth)?;
            Some(vec![contains_filter(doc_field.clone(), needle)])
        }
        (PredicateValue::Field(doc_field), PredicateValue::Literal(lit)) => {
            Some(vec![contains_filter(doc_field.clone(), lit.clone())])
        }
        _ => None,
    }
}
/// Translates `left INTERSECTS right` (the operands share at least one element) into
/// scan filters.
///
/// * **Exactly one side is a document field.** Emits an `any_in` filter on that field.
///   The filter's value is the other side: a literal, or a resolved auth value.
/// * **Neither side is a document field.** Both sides are resolved now and the result
///   is decided statically:
///   - the match-all sentinel when the two JSON arrays share an element;
///   - otherwise the deny sentinel. Non-array operands never intersect.
/// * **Both sides are document fields.** Not expressible; returns `None`.
///
/// Also returns `None` when an auth value cannot be resolved.
fn substitute_intersects(
    left: &PredicateValue,
    right: &PredicateValue,
    auth: &AuthContext,
) -> Option<Vec<ScanFilter>> {
    // Resolves a non-field operand to a concrete value.
    // Returns `None` for a document field or an unresolvable auth value.
    fn resolve_operand(operand: &PredicateValue, auth: &AuthContext) -> Option<serde_json::Value> {
        match operand {
            PredicateValue::Literal(v) => Some(v.clone()),
            // Resolve auth variables exactly as comparisons do, whatever the other side is.
            PredicateValue::AuthRef(name) => auth.resolve_variable(name),
            PredicateValue::AuthFunc { .. } => operand.resolve(auth),
            PredicateValue::Field(_) => None,
        }
    }

    match (left, right) {
        (PredicateValue::Field(_), PredicateValue::Field(_)) => None,
        (PredicateValue::Field(doc_field), other) | (other, PredicateValue::Field(doc_field)) => {
            let value = resolve_operand(other, auth)?;
            Some(vec![ScanFilter {
                field: doc_field.clone(),
                op: "any_in".into(),
                value,
                clauses: Vec::new(),
            }])
        }
        _ => {
            let left_val = resolve_operand(left, auth)?;
            let right_val = resolve_operand(right, auth)?;
            let intersects = match (left_val.as_array(), right_val.as_array()) {
                (Some(l), Some(r)) => l.iter().any(|v| r.contains(v)),
                _ => false,
            };
            let filter = if intersects {
                ScanFilter {
                    field: String::new(),
                    op: "match_all".into(),
                    value: serde_json::Value::Null,
                    clauses: Vec::new(),
                }
            } else {
                ScanFilter {
                    field: "__rls_deny__".into(),
                    op: "is_not_null".into(),
                    value: serde_json::Value::Null,
                    clauses: Vec::new(),
                }
            };
            Some(vec![filter])
        }
    }
}
/// Translates `NOT inner` into scan filters.
///
/// Supported shapes:
/// * `NOT true` / `NOT false` become the deny / match-all sentinels.
/// * `NOT NOT p` collapses to `p`.
/// * `NOT (a OR b OR ...)` becomes `NOT a AND NOT b AND ...` (De Morgan).
///   - Every child must itself be negatable.
///   - An unsupported child makes the whole result `None` rather than being dropped,
///     because dropping a conjunct here would widen the visible row set.
/// * `NOT compare` uses the complementary comparison operator. NOTE(review): for rows
///   where the field is missing/null the complementary operator may not match even
///   though the logical negation would — confirm this is the intended semantics.
///
/// Anything else returns `None`.
fn substitute_not(inner: &RlsPredicate, auth: &AuthContext) -> Option<Vec<ScanFilter>> {
    match inner {
        RlsPredicate::AlwaysTrue => substitute_to_scan_filters(&RlsPredicate::AlwaysFalse, auth),
        RlsPredicate::AlwaysFalse => substitute_to_scan_filters(&RlsPredicate::AlwaysTrue, auth),
        // Double negation.
        RlsPredicate::Not(doubly_negated) => substitute_to_scan_filters(doubly_negated, auth),
        // De Morgan: the negated disjuncts are AND-ed; any failure propagates as `None`.
        RlsPredicate::Or(children) => {
            let mut combined = Vec::new();
            for child in children {
                combined.extend(substitute_not(child, auth)?);
            }
            Some(combined)
        }
        RlsPredicate::Compare { field, op, value } => {
            let negated_op = match op {
                CompareOp::Eq => CompareOp::Ne,
                CompareOp::Ne => CompareOp::Eq,
                CompareOp::Gt => CompareOp::Lte,
                CompareOp::Gte => CompareOp::Lt,
                CompareOp::Lt => CompareOp::Gte,
                CompareOp::Lte => CompareOp::Gt,
                CompareOp::In => CompareOp::NotIn,
                CompareOp::NotIn => CompareOp::In,
                CompareOp::IsNull => CompareOp::IsNotNull,
                CompareOp::IsNotNull => CompareOp::IsNull,
                // Operators without a known complement cannot be negated.
                _ => return None,
            };
            substitute_to_scan_filters(
                &RlsPredicate::Compare {
                    field: field.clone(),
                    op: negated_op,
                    value: value.clone(),
                },
                auth,
            )
        }
        _ => None,
    }
}