use serde::{Deserialize, Serialize};
use super::auth_context::AuthContext;
/// Boolean predicate tree describing an RLS policy condition
/// (RLS presumably stands for "row-level security" — confirm with the module docs).
///
/// Leaf variants compare or test values; `And`, `Or` and `Not` combine
/// sub-predicates. The type is serde-serializable, so variant and field
/// names are part of the serialized representation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RlsPredicate {
/// Comparison of the named `field` against `value` using operator `op`.
Compare {
field: String,
op: CompareOp,
value: PredicateValue,
},
/// Membership test: whether `set` contains `element`.
/// Either side may be a field, a literal, or an auth reference.
Contains {
set: PredicateValue,
element: PredicateValue,
},
/// Overlap test between `left` and `right`
/// (presumably both are set-valued — confirm against the evaluator).
Intersects {
left: PredicateValue,
right: PredicateValue,
},
/// Conjunction of the contained predicates.
And(Vec<RlsPredicate>),
/// Disjunction of the contained predicates.
Or(Vec<RlsPredicate>),
/// Negation of the boxed predicate (boxed because the type is recursive).
Not(Box<RlsPredicate>),
/// Constant predicate; presumably always satisfied — confirm against the evaluator.
AlwaysTrue,
/// Constant predicate; presumably never satisfied — confirm against the evaluator.
AlwaysFalse,
}
/// Comparison operator used by `RlsPredicate::Compare`.
///
/// Each variant has a filter-op name (see `as_filter_op`) and one or more
/// SQL-style spellings accepted by `from_str_sql`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CompareOp {
/// `=` — filter op `eq`.
Eq,
/// `!=` or `<>` — filter op `ne`.
Ne,
/// `>` — filter op `gt`.
Gt,
/// `>=` — filter op `gte`.
Gte,
/// `<` — filter op `lt`.
Lt,
/// `<=` — filter op `lte`.
Lte,
/// `IN` — filter op `in`.
In,
/// `NOT IN` / `NOT_IN` — filter op `not_in`.
NotIn,
/// `LIKE` — filter op `like`.
Like,
/// `ILIKE` — filter op `ilike`.
ILike,
/// `IS NULL` / `IS_NULL` — filter op `is_null`.
IsNull,
/// `IS NOT NULL` / `IS_NOT_NULL` — filter op `is_not_null`.
IsNotNull,
}
impl CompareOp {
pub fn as_filter_op(&self) -> &'static str {
match self {
Self::Eq => "eq",
Self::Ne => "ne",
Self::Gt => "gt",
Self::Gte => "gte",
Self::Lt => "lt",
Self::Lte => "lte",
Self::In => "in",
Self::NotIn => "not_in",
Self::Like => "like",
Self::ILike => "ilike",
Self::IsNull => "is_null",
Self::IsNotNull => "is_not_null",
}
}
pub fn from_str_sql(s: &str) -> Option<Self> {
match s {
"=" => Some(Self::Eq),
"!=" | "<>" => Some(Self::Ne),
">" => Some(Self::Gt),
">=" => Some(Self::Gte),
"<" => Some(Self::Lt),
"<=" => Some(Self::Lte),
_ => {
let upper = s.to_uppercase();
match upper.as_str() {
"IN" => Some(Self::In),
"NOT_IN" | "NOT IN" => Some(Self::NotIn),
"LIKE" => Some(Self::Like),
"ILIKE" => Some(Self::ILike),
"IS_NULL" | "IS NULL" => Some(Self::IsNull),
"IS_NOT_NULL" | "IS NOT NULL" => Some(Self::IsNotNull),
_ => None,
}
}
}
}
}
/// A value operand appearing inside an `RlsPredicate`.
///
/// `resolve` turns an operand into a concrete JSON value using an
/// `AuthContext`, where that is possible.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PredicateValue {
/// A constant JSON value; resolves to a clone of itself.
Literal(serde_json::Value),
/// A named field (presumably a column of the row being checked — confirm).
/// It cannot be resolved from the auth context, so `resolve` returns `None`.
Field(String),
/// A named auth variable, looked up via `AuthContext::resolve_variable`.
AuthRef(String),
/// A function-style auth reference. `resolve` looks up the variable
/// `metadata.<func>.<first arg>`; only the first argument is used and a
/// missing argument is treated as the empty string.
AuthFunc { func: String, args: Vec<String> },
}
impl PredicateValue {
    /// Reports whether this operand refers to the auth context, i.e. is an
    /// `AuthRef` or an `AuthFunc`.
    pub fn is_auth_ref(&self) -> bool {
        match self {
            Self::AuthRef(_) | Self::AuthFunc { .. } => true,
            Self::Literal(_) | Self::Field(_) => false,
        }
    }

    /// Resolves this operand to a concrete JSON value using `auth`.
    ///
    /// * `Literal` yields a clone of the stored value.
    /// * `Field` yields `None`, since it is not derived from the auth context.
    /// * `AuthRef` yields whatever `auth.resolve_variable` returns for the name.
    /// * `AuthFunc` looks up `metadata.<func>.<first arg>`; only the first
    ///   argument is consulted, and the empty string is used when there is none.
    pub fn resolve(&self, auth: &AuthContext) -> Option<serde_json::Value> {
        match self {
            Self::Field(_) => None,
            Self::Literal(literal) => Some(literal.clone()),
            Self::AuthRef(name) => auth.resolve_variable(name),
            Self::AuthFunc { func, args } => {
                let first_arg = args.first().map_or("", String::as_str);
                let variable = format!("metadata.{func}.{first_arg}");
                auth.resolve_variable(&variable)
            }
        }
    }
}
/// How a policy participates when several policies are combined.
///
/// NOTE(review): the combination rules live outside this file; the names
/// suggest PostgreSQL-style semantics (permissive policies OR-ed together,
/// restrictive policies AND-ed on top) — confirm against `combine_policies`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub enum PolicyMode {
/// The default mode.
#[default]
Permissive,
/// The non-default, restrictive mode.
Restrictive,
}
// Unit tests covering predicate substitution into scan filters, policy
// combination, predicate parsing and auth-reference validation.
#[cfg(test)]
mod tests {
use super::*;
use crate::control::security::auth_context::AuthContext;
use crate::control::security::identity::{AuthMethod, AuthenticatedIdentity, Role};
use crate::control::security::predicate_eval::{combine_policies, substitute_to_scan_filters};
use crate::control::security::predicate_parser::{parse_predicate, validate_auth_refs};
use crate::types::TenantId;
/// Builds the auth context shared by the tests: non-superuser "alice"
/// (user id 123, tenant 1, `ReadWrite` role, API-key auth) in session
/// "test-session".
fn make_auth() -> AuthContext {
let identity = AuthenticatedIdentity {
user_id: 123,
username: "alice".into(),
tenant_id: TenantId::new(1),
roles: vec![Role::ReadWrite],
auth_method: AuthMethod::ApiKey,
is_superuser: false,
};
AuthContext::from_identity(&identity, "test-session".into())
}
/// `field = <auth ref "id">` becomes a single `eq` filter whose value is
/// the user id rendered as the JSON string "123".
#[test]
fn simple_equality_substitution() {
let pred = RlsPredicate::Compare {
field: "user_id".into(),
op: CompareOp::Eq,
value: PredicateValue::AuthRef("id".into()),
};
let auth = make_auth();
let filters = substitute_to_scan_filters(&pred, &auth).unwrap();
assert_eq!(filters.len(), 1);
assert_eq!(filters[0].field, "user_id");
assert_eq!(filters[0].op, "eq");
assert_eq!(filters[0].value, serde_json::json!("123"));
}
/// A literal operand is carried into the filter value unchanged.
#[test]
fn literal_comparison() {
let pred = RlsPredicate::Compare {
field: "status".into(),
op: CompareOp::Eq,
value: PredicateValue::Literal(serde_json::json!("active")),
};
let auth = make_auth();
let filters = substitute_to_scan_filters(&pred, &auth).unwrap();
assert_eq!(filters[0].value, serde_json::json!("active"));
}
/// `<auth roles> CONTAINS <literal>` involves no row field; for this auth
/// context the resulting filter op is "match_all".
#[test]
fn auth_contains_check() {
let pred = RlsPredicate::Contains {
set: PredicateValue::AuthRef("roles".into()),
element: PredicateValue::Literal(serde_json::json!("readwrite")),
};
let auth = make_auth();
let filters = substitute_to_scan_filters(&pred, &auth).unwrap();
assert_eq!(filters[0].op, "match_all");
}
/// `<field> CONTAINS <auth ref>` becomes a `contains` filter on that field
/// with the resolved auth value.
#[test]
fn field_contains_auth_ref() {
let pred = RlsPredicate::Contains {
set: PredicateValue::Field("allowed_users".into()),
element: PredicateValue::AuthRef("id".into()),
};
let auth = make_auth();
let filters = substitute_to_scan_filters(&pred, &auth).unwrap();
assert_eq!(filters[0].field, "allowed_users");
assert_eq!(filters[0].op, "contains");
assert_eq!(filters[0].value, serde_json::json!("123"));
}
/// An `And` of two comparisons yields two separate filters.
#[test]
fn and_combination() {
let pred = RlsPredicate::And(vec![
RlsPredicate::Compare {
field: "tenant_id".into(),
op: CompareOp::Eq,
value: PredicateValue::AuthRef("tenant_id".into()),
},
RlsPredicate::Compare {
field: "status".into(),
op: CompareOp::Eq,
value: PredicateValue::Literal(serde_json::json!("published")),
},
]);
let auth = make_auth();
let filters = substitute_to_scan_filters(&pred, &auth).unwrap();
assert_eq!(filters.len(), 2);
}
/// An `Or` of two comparisons yields one `or` filter holding two clauses.
#[test]
fn or_combination() {
let pred = RlsPredicate::Or(vec![
RlsPredicate::Compare {
field: "owner".into(),
op: CompareOp::Eq,
value: PredicateValue::AuthRef("id".into()),
},
RlsPredicate::Compare {
field: "visibility".into(),
op: CompareOp::Eq,
value: PredicateValue::Literal(serde_json::json!("public")),
},
]);
let auth = make_auth();
let filters = substitute_to_scan_filters(&pred, &auth).unwrap();
assert_eq!(filters.len(), 1);
assert_eq!(filters[0].op, "or");
assert_eq!(filters[0].clauses.len(), 2);
}
/// An auth reference that cannot be resolved makes substitution return
/// `None` rather than producing a filter.
#[test]
fn missing_auth_ref_denies() {
let pred = RlsPredicate::Compare {
field: "org_id".into(),
op: CompareOp::Eq,
value: PredicateValue::AuthRef("nonexistent_field".into()),
};
let auth = make_auth();
assert!(substitute_to_scan_filters(&pred, &auth).is_none());
}
/// One permissive plus one restrictive policy combine into two filters.
#[test]
fn policy_combination_permissive_and_restrictive() {
let policies = vec![
(
RlsPredicate::Compare {
field: "user_id".into(),
op: CompareOp::Eq,
value: PredicateValue::AuthRef("id".into()),
},
PolicyMode::Permissive,
),
(
RlsPredicate::Compare {
field: "status".into(),
op: CompareOp::Eq,
value: PredicateValue::Literal(serde_json::json!("active")),
},
PolicyMode::Restrictive,
),
];
let auth = make_auth();
let filters = combine_policies(&policies, &auth).unwrap();
assert_eq!(filters.len(), 2);
}
/// With no policies at all, combination succeeds with an empty filter list.
#[test]
fn empty_policies_allow_all() {
let auth = make_auth();
let filters = combine_policies(&[], &auth).unwrap();
assert!(filters.is_empty());
}
/// `user_id = $auth.id` parses to a `Compare` with `Eq` and `AuthRef("id")`.
#[test]
fn parse_simple_equality() {
let pred = parse_predicate("user_id = $auth.id").unwrap();
match pred {
RlsPredicate::Compare { field, op, value } => {
assert_eq!(field, "user_id");
assert_eq!(op, CompareOp::Eq);
assert!(matches!(value, PredicateValue::AuthRef(ref f) if f == "id"));
}
_ => panic!("expected Compare"),
}
}
/// `a AND b OR c` parses with `Or` as the top-level node.
#[test]
fn parse_and_or() {
let pred =
parse_predicate("user_id = $auth.id AND status = 'active' OR visibility = 'public'")
.unwrap();
assert!(matches!(pred, RlsPredicate::Or(_)));
}
/// Parentheses group the `OR`, so `(a OR b) AND c` has `And` at the top.
#[test]
fn parse_parenthesized() {
let pred =
parse_predicate("(user_id = $auth.id OR shared = true) AND status = 'active'").unwrap();
assert!(matches!(pred, RlsPredicate::And(_)));
}
/// The `CONTAINS` keyword parses to `RlsPredicate::Contains`.
#[test]
fn parse_contains() {
let pred = parse_predicate("$auth.roles CONTAINS 'admin'").unwrap();
assert!(matches!(pred, RlsPredicate::Contains { .. }));
}
/// The `INTERSECTS` keyword parses to `RlsPredicate::Intersects`.
#[test]
fn parse_intersects() {
let pred = parse_predicate("allowed_groups INTERSECTS $auth.groups").unwrap();
assert!(matches!(pred, RlsPredicate::Intersects { .. }));
}
/// A leading `NOT` parses to `RlsPredicate::Not`.
#[test]
fn parse_not() {
let pred = parse_predicate("NOT status = 'archived'").unwrap();
assert!(matches!(pred, RlsPredicate::Not(_)));
}
/// `$auth.id` and `$auth.roles` are accepted by auth-reference validation.
#[test]
fn validate_auth_refs_passes() {
let pred = parse_predicate("user_id = $auth.id AND $auth.roles CONTAINS 'admin'").unwrap();
assert!(validate_auth_refs(&pred).is_ok());
}
/// An unrecognised auth reference (`$auth.foobar`) parses but fails validation.
#[test]
fn validate_auth_refs_rejects_unknown() {
let pred = parse_predicate("user_id = $auth.foobar").unwrap();
assert!(validate_auth_refs(&pred).is_err());
}
/// References under `$auth.metadata.` are accepted by validation.
#[test]
fn validate_auth_refs_allows_metadata() {
let pred = parse_predicate("plan = $auth.metadata.plan").unwrap();
assert!(validate_auth_refs(&pred).is_ok());
}
/// Negating an `Eq` comparison yields a filter with op "ne".
#[test]
fn not_negation_produces_correct_op() {
let pred = RlsPredicate::Not(Box::new(RlsPredicate::Compare {
field: "active".into(),
op: CompareOp::Eq,
value: PredicateValue::Literal(serde_json::json!(true)),
}));
let auth = make_auth();
let filters = substitute_to_scan_filters(&pred, &auth).unwrap();
assert_eq!(filters[0].op, "ne");
}
}