use super::auth_context::AuthContext;
use super::predicate::{CompareOp, PolicyMode, PredicateValue, RlsPredicate};
use crate::bridge::scan_filter::{FilterOp, ScanFilter};
/// Lowers an RLS predicate into scan filters, substituting values taken from
/// the caller's [`AuthContext`].
///
/// The returned filters are implicitly AND-ed together. Constant predicates
/// become sentinel filters: `AlwaysTrue` is a single `match_all` filter and
/// `AlwaysFalse` is an `is_not_null` test on the reserved `__rls_deny__` field.
///
/// Returns `None` when the predicate cannot be lowered — for example a
/// comparison against another document field, an auth reference that does not
/// resolve, or an unsupported negation. Inside an `Or`, a child that cannot be
/// lowered is dropped instead of failing the whole disjunction; if no child
/// survives, the disjunction lowers to the deny sentinel.
pub fn substitute_to_scan_filters(
    predicate: &RlsPredicate,
    auth: &AuthContext,
) -> Option<Vec<ScanFilter>> {
    // Sentinel for a predicate that admits every row.
    fn allow_all() -> Vec<ScanFilter> {
        vec![ScanFilter {
            field: String::new(),
            op: "match_all".into(),
            value: nodedb_types::Value::Null,
            clauses: Vec::new(),
            expr: None,
        }]
    }

    // Sentinel for a predicate that admits no row: a not-null test on a
    // reserved field name.
    fn deny_all() -> Vec<ScanFilter> {
        vec![ScanFilter {
            field: "__rls_deny__".into(),
            op: "is_not_null".into(),
            value: nodedb_types::Value::Null,
            clauses: Vec::new(),
            expr: None,
        }]
    }

    match predicate {
        RlsPredicate::AlwaysTrue => Some(allow_all()),
        RlsPredicate::AlwaysFalse => Some(deny_all()),
        RlsPredicate::Compare { field, op, value } => {
            let bound = match value {
                // Field-to-field comparisons have no scan-filter form.
                PredicateValue::Field(_) => return None,
                PredicateValue::Literal(literal) => literal.clone(),
                PredicateValue::AuthRef(name) => auth.resolve_variable(name)?,
                PredicateValue::AuthFunc { .. } => value.resolve(auth)?,
            };
            Some(vec![ScanFilter {
                field: field.clone(),
                op: op.as_filter_op().into(),
                value: nodedb_types::Value::from(bound),
                clauses: Vec::new(),
                expr: None,
            }])
        }
        RlsPredicate::Contains { set, element } => substitute_contains(set, element, auth),
        RlsPredicate::Intersects { left, right } => substitute_intersects(left, right, auth),
        // Conjunction: concatenate every child's filters; a single unlowerable
        // child makes the whole conjunction unlowerable.
        RlsPredicate::And(children) => children.iter().try_fold(Vec::new(), |mut all, child| {
            all.extend(substitute_to_scan_filters(child, auth)?);
            Some(all)
        }),
        RlsPredicate::Or(children) => {
            let mut alternatives: Vec<Vec<ScanFilter>> = Vec::new();
            for filters in children
                .iter()
                .filter_map(|child| substitute_to_scan_filters(child, auth))
            {
                // One always-true branch makes the whole disjunction true.
                if matches!(filters.as_slice(), [only] if only.op == FilterOp::MatchAll) {
                    return Some(filters);
                }
                alternatives.push(filters);
            }
            match alternatives.len() {
                0 => Some(deny_all()),
                1 => alternatives.pop(),
                _ => Some(vec![ScanFilter {
                    field: String::new(),
                    op: "or".into(),
                    value: nodedb_types::Value::Null,
                    clauses: alternatives,
                    expr: None,
                }]),
            }
        }
        RlsPredicate::Not(inner) => substitute_not(inner, auth),
    }
}
pub fn combine_policies(
policies: &[(RlsPredicate, PolicyMode)],
auth: &AuthContext,
) -> Option<Vec<ScanFilter>> {
if policies.is_empty() {
return Some(Vec::new()); }
let mut permissive: Vec<&RlsPredicate> = Vec::new();
let mut restrictive: Vec<&RlsPredicate> = Vec::new();
for (pred, mode) in policies {
match mode {
PolicyMode::Permissive => permissive.push(pred),
PolicyMode::Restrictive => restrictive.push(pred),
}
}
let mut combined = Vec::new();
if permissive.len() == 1 {
combined.extend(substitute_to_scan_filters(permissive[0], auth)?);
} else if permissive.len() > 1 {
let or_children: Vec<RlsPredicate> = permissive.iter().map(|p| (*p).clone()).collect();
let or_pred = RlsPredicate::Or(or_children);
combined.extend(substitute_to_scan_filters(&or_pred, auth)?);
}
for pred in &restrictive {
combined.extend(substitute_to_scan_filters(pred, auth)?);
}
Some(combined)
}
/// Lowers a `Contains { set, element }` predicate.
///
/// Supported operand shapes:
/// * An auth-side set (`AuthRef` or `AuthFunc`) with a literal element is
///   decided eagerly. The result is a `match_all` filter if the resolved set is
///   an array containing the literal, and the `__rls_deny__` sentinel
///   otherwise. A set that fails to resolve, or is not an array, yields `None`.
/// * A document-field set with an auth-side or literal element becomes a
///   `contains` scan filter on that field.
///
/// Every other operand combination returns `None`.
fn substitute_contains(
    set: &PredicateValue,
    element: &PredicateValue,
    auth: &AuthContext,
) -> Option<Vec<ScanFilter>> {
    // A single `contains` filter on a document field.
    fn contains_filter(field: &str, value: nodedb_types::Value) -> Vec<ScanFilter> {
        vec![ScanFilter {
            field: field.to_owned(),
            op: "contains".into(),
            value,
            clauses: Vec::new(),
            expr: None,
        }]
    }

    // Constant outcome for a membership test decided from the auth context alone.
    fn decided(allowed: bool) -> Vec<ScanFilter> {
        if allowed {
            vec![ScanFilter {
                field: String::new(),
                op: "match_all".into(),
                value: nodedb_types::Value::Null,
                clauses: Vec::new(),
                expr: None,
            }]
        } else {
            vec![ScanFilter {
                field: "__rls_deny__".into(),
                op: "is_not_null".into(),
                value: nodedb_types::Value::Null,
                clauses: Vec::new(),
                expr: None,
            }]
        }
    }

    // `AuthRef` resolves through the auth context's variable lookup and
    // `AuthFunc` through `PredicateValue::resolve`. Only called with those two.
    let resolve_auth = |value: &PredicateValue| match value {
        PredicateValue::AuthRef(name) => auth.resolve_variable(name),
        _ => value.resolve(auth),
    };

    match (set, element) {
        (
            PredicateValue::AuthRef(_) | PredicateValue::AuthFunc { .. },
            PredicateValue::Literal(needle),
        ) => {
            let resolved_set = resolve_auth(set)?;
            let members = resolved_set.as_array()?;
            Some(decided(members.contains(needle)))
        }
        (
            PredicateValue::Field(doc_field),
            PredicateValue::AuthRef(_) | PredicateValue::AuthFunc { .. },
        ) => {
            let resolved = resolve_auth(element)?;
            Some(contains_filter(doc_field, nodedb_types::Value::from(resolved)))
        }
        (PredicateValue::Field(doc_field), PredicateValue::Literal(literal)) => Some(
            contains_filter(doc_field, nodedb_types::Value::from(literal.clone())),
        ),
        _ => None,
    }
}
fn substitute_intersects(
left: &PredicateValue,
right: &PredicateValue,
auth: &AuthContext,
) -> Option<Vec<ScanFilter>> {
match (left, right) {
(PredicateValue::Field(doc_field), PredicateValue::AuthRef(auth_field))
| (PredicateValue::AuthRef(auth_field), PredicateValue::Field(doc_field)) => {
let auth_val = auth.resolve_variable(auth_field)?;
Some(vec![ScanFilter {
field: doc_field.clone(),
op: "any_in".into(),
value: nodedb_types::Value::from(auth_val),
clauses: Vec::new(),
expr: None,
}])
}
(PredicateValue::Field(doc_field), PredicateValue::AuthFunc { .. }) => {
let auth_val = right.resolve(auth)?;
Some(vec![ScanFilter {
field: doc_field.clone(),
op: "any_in".into(),
value: nodedb_types::Value::from(auth_val),
clauses: Vec::new(),
expr: None,
}])
}
(PredicateValue::AuthFunc { .. }, PredicateValue::Field(doc_field)) => {
let auth_val = left.resolve(auth)?;
Some(vec![ScanFilter {
field: doc_field.clone(),
op: "any_in".into(),
value: nodedb_types::Value::from(auth_val),
clauses: Vec::new(),
expr: None,
}])
}
(PredicateValue::AuthRef(left_field), PredicateValue::AuthRef(right_field)) => {
let left_val = auth.resolve_variable(left_field)?;
let right_val = auth.resolve_variable(right_field)?;
let intersects = if let (Some(l), Some(r)) = (left_val.as_array(), right_val.as_array())
{
l.iter().any(|v| r.contains(v))
} else {
false
};
if intersects {
Some(vec![ScanFilter {
field: String::new(),
op: "match_all".into(),
value: nodedb_types::Value::Null,
clauses: Vec::new(),
expr: None,
}])
} else {
Some(vec![ScanFilter {
field: "__rls_deny__".into(),
op: "is_not_null".into(),
value: nodedb_types::Value::Null,
clauses: Vec::new(),
expr: None,
}])
}
}
(PredicateValue::AuthFunc { .. }, PredicateValue::AuthFunc { .. })
| (PredicateValue::AuthRef(_), PredicateValue::AuthFunc { .. })
| (PredicateValue::AuthFunc { .. }, PredicateValue::AuthRef(_)) => {
let left_val = left.resolve(auth)?;
let right_val = right.resolve(auth)?;
let intersects = if let (Some(l), Some(r)) = (left_val.as_array(), right_val.as_array())
{
l.iter().any(|v| r.contains(v))
} else {
false
};
if intersects {
Some(vec![ScanFilter {
field: String::new(),
op: "match_all".into(),
value: nodedb_types::Value::Null,
clauses: Vec::new(),
expr: None,
}])
} else {
Some(vec![ScanFilter {
field: "__rls_deny__".into(),
op: "is_not_null".into(),
value: nodedb_types::Value::Null,
clauses: Vec::new(),
expr: None,
}])
}
}
_ => None,
}
}
fn substitute_not(inner: &RlsPredicate, auth: &AuthContext) -> Option<Vec<ScanFilter>> {
match inner {
RlsPredicate::AlwaysTrue => substitute_to_scan_filters(&RlsPredicate::AlwaysFalse, auth),
RlsPredicate::AlwaysFalse => substitute_to_scan_filters(&RlsPredicate::AlwaysTrue, auth),
RlsPredicate::Compare { field, op, value } => {
let negated_op = match op {
CompareOp::Eq => CompareOp::Ne,
CompareOp::Ne => CompareOp::Eq,
CompareOp::Gt => CompareOp::Lte,
CompareOp::Gte => CompareOp::Lt,
CompareOp::Lt => CompareOp::Gte,
CompareOp::Lte => CompareOp::Gt,
CompareOp::In => CompareOp::NotIn,
CompareOp::NotIn => CompareOp::In,
CompareOp::IsNull => CompareOp::IsNotNull,
CompareOp::IsNotNull => CompareOp::IsNull,
_ => return None, };
substitute_to_scan_filters(
&RlsPredicate::Compare {
field: field.clone(),
op: negated_op,
value: value.clone(),
},
auth,
)
}
_ => None, }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::control::security::auth_context::AuthContext;
    use crate::control::security::identity::{
        AuthMethod, AuthenticatedIdentity, DatabaseSet, Role,
    };
    use crate::control::security::predicate::{
        CompareOp, PolicyMode, PredicateValue, RlsPredicate,
    };
    use crate::types::TenantId;
    use nodedb_types::id::DatabaseId;

    /// A read-write, non-superuser identity in tenant 1 with no default database.
    fn test_identity() -> AuthenticatedIdentity {
        AuthenticatedIdentity {
            user_id: 42,
            username: "alice".into(),
            tenant_id: TenantId::new(1),
            auth_method: AuthMethod::ScramSha256,
            roles: vec![Role::ReadWrite],
            is_superuser: false,
            default_database: None,
            accessible_databases: DatabaseSet::Some(smallvec::smallvec![DatabaseId::DEFAULT]),
        }
    }

    /// Auth context for `test_identity`, exactly as `from_identity` builds it.
    fn auth_without_database() -> AuthContext {
        AuthContext::from_identity(&test_identity(), "s_test".into())
    }

    /// Auth context for `test_identity` with its session bound to `db_id`.
    fn auth_with_database(db_id: DatabaseId) -> AuthContext {
        let mut ctx = auth_without_database();
        ctx.database_id = Some(db_id);
        ctx
    }

    /// The predicate `<field> = $auth.database_id`.
    fn field_eq_database_id(field: &str) -> RlsPredicate {
        RlsPredicate::Compare {
            field: field.into(),
            op: CompareOp::Eq,
            value: PredicateValue::AuthRef("database_id".into()),
        }
    }

    #[test]
    fn database_id_auth_ref_substitutes_correctly() {
        let db_id = DatabaseId::new(99);
        let filters = substitute_to_scan_filters(
            &field_eq_database_id("owning_db"),
            &auth_with_database(db_id),
        )
        .expect("should resolve when database_id is set");

        assert_eq!(filters.len(), 1);
        let filter = &filters[0];
        assert_eq!(filter.field, "owning_db");
        let nodedb_types::Value::Integer(n) = &filter.value else {
            panic!("expected numeric value, got {:?}", filter.value);
        };
        assert_eq!(*n as u64, db_id.as_u64());
    }

    #[test]
    fn database_id_auth_ref_fails_closed_when_none() {
        let outcome = substitute_to_scan_filters(
            &field_eq_database_id("owning_db"),
            &auth_without_database(),
        );
        assert!(
            outcome.is_none(),
            "predicate must fail closed when database_id is None"
        );
    }

    #[test]
    fn combine_database_id_policy_passes_when_set() {
        let auth = auth_with_database(DatabaseId::new(77));
        let policies = [(field_eq_database_id("db"), PolicyMode::Permissive)];
        let combined = combine_policies(&policies, &auth);
        assert!(
            combined.as_ref().is_some_and(|v| !v.is_empty()),
            "should produce scan filters when database_id is bound"
        );
    }
}