use super::predicate::{CompareOp, PredicateValue, RlsPredicate};
/// Parses a predicate expression into an [`RlsPredicate`] tree.
///
/// The input is tokenized, then parsed by recursive descent with `OR` as
/// the loosest-binding operator, `AND` binding tighter, and `NOT`,
/// parenthesised groups and single comparisons as atoms.
///
/// # Errors
///
/// Returns any error raised by the tokenizer or the sub-parsers, and
/// [`PredicateParseError::UnexpectedToken`] when tokens remain after a
/// complete expression has been read. `position` in that error is the
/// index of the offending token in the token list.
pub fn parse_predicate(input: &str) -> Result<RlsPredicate, PredicateParseError> {
    let tokens = tokenize(input)?;
    let mut cursor = 0usize;
    let predicate = parse_or_expr(&tokens, &mut cursor)?;

    // A successful parse must consume every token; anything left is trailing garbage.
    match tokens.get(cursor) {
        Some(leftover) => Err(PredicateParseError::UnexpectedToken {
            token: leftover.clone(),
            position: cursor,
        }),
        None => Ok(predicate),
    }
}
/// Errors produced while tokenizing or parsing a predicate expression.
#[derive(Debug, Clone, thiserror::Error)]
pub enum PredicateParseError {
/// A token appeared where it is not allowed. `position` is the index of
/// the token within the token list, not a character offset into the input.
#[error("unexpected token '{token}' at position {position}")]
UnexpectedToken { token: String, position: usize },
/// The token list ended while the parser still expected another token.
#[error("unexpected end of expression")]
UnexpectedEnd,
/// The token found in operator position is not a recognised comparison
/// operator; the payload is that token.
#[error("unknown operator: '{0}'")]
UnknownOperator(String),
/// A literal could not be read, e.g. a quoted string that is missing its
/// closing quote; the payload describes the problem.
#[error("invalid literal: '{0}'")]
InvalidLiteral(String),
/// An opening `(` was not followed by a matching `)`.
#[error("unmatched parenthesis")]
UnmatchedParen,
}
/// Splits a predicate expression into its lexical tokens.
///
/// Token kinds produced:
/// - `(` and `)` as single-character tokens;
/// - single-quoted strings, where `''` inside the quotes stands for one
///   literal quote. The emitted token is the unescaped content wrapped in
///   one pair of quotes, so later stages can tell it from a bare word;
/// - operators `=`, `!`, `<`, `>` and the two-character forms `!=`, `<=`,
///   `<>`, `>=`;
/// - words, which start with an alphanumeric character, `_`, `$` or `-`
///   and continue with any of those or `.`.
///
/// Whitespace separates tokens and is otherwise ignored.
///
/// # Errors
///
/// - [`PredicateParseError::InvalidLiteral`] when a quoted string is not
///   terminated.
/// - [`PredicateParseError::UnexpectedToken`] when a character cannot start
///   any token. Such characters used to be dropped silently, which let a
///   malformed expression parse as a different predicate. `position` is the
///   index the token would have had in the token list.
fn tokenize(input: &str) -> Result<Vec<String>, PredicateParseError> {
    let mut tokens = Vec::new();
    let mut chars = input.chars().peekable();

    while let Some(&ch) = chars.peek() {
        match ch {
            _ if ch.is_whitespace() => {
                chars.next();
            }
            '(' | ')' => {
                chars.next();
                tokens.push(ch.to_string());
            }
            '\'' => {
                chars.next(); // opening quote
                let mut literal = String::new();
                loop {
                    match chars.next() {
                        Some('\'') => {
                            // A doubled quote is an escaped quote; a single one closes the string.
                            if chars.next_if_eq(&'\'').is_some() {
                                literal.push('\'');
                            } else {
                                break;
                            }
                        }
                        Some(c) => literal.push(c),
                        None => {
                            return Err(PredicateParseError::InvalidLiteral(
                                "unterminated string".into(),
                            ));
                        }
                    }
                }
                tokens.push(format!("'{literal}'"));
            }
            '=' | '!' | '<' | '>' => {
                chars.next();
                let mut op = String::from(ch);
                if let Some(&next) = chars.peek() {
                    let is_two_char_op = matches!(
                        (ch, next),
                        ('!', '=') | ('<', '=') | ('<', '>') | ('>', '=')
                    );
                    if is_two_char_op {
                        op.push(next);
                        chars.next();
                    }
                }
                tokens.push(op);
            }
            _ if ch.is_alphanumeric() || matches!(ch, '_' | '$' | '-') => {
                let mut word = String::new();
                while let Some(c) = chars
                    .next_if(|&c| c.is_alphanumeric() || matches!(c, '_' | '.' | '$' | '-'))
                {
                    word.push(c);
                }
                tokens.push(word);
            }
            unexpected => {
                // Fail closed: never discard input that the grammar does not understand.
                return Err(PredicateParseError::UnexpectedToken {
                    token: unexpected.to_string(),
                    position: tokens.len(),
                });
            }
        }
    }

    Ok(tokens)
}
/// Parses one or more `AND`-expressions joined by `OR` (case-insensitive).
///
/// A single operand is returned unchanged. With two or more operands the
/// result is `RlsPredicate::Or`; if the first operand is itself an `Or`,
/// the later operands are appended to its children instead of nesting.
///
/// `pos` is advanced past every token consumed.
fn parse_or_expr(tokens: &[String], pos: &mut usize) -> Result<RlsPredicate, PredicateParseError> {
    let is_or_at = |at: usize| {
        tokens
            .get(at)
            .is_some_and(|tok| tok.eq_ignore_ascii_case("OR"))
    };

    let first = parse_and_expr(tokens, pos)?;
    if !is_or_at(*pos) {
        return Ok(first);
    }

    // Reuse the children of a leading `Or` so the chain stays flat.
    let mut branches = match first {
        RlsPredicate::Or(existing) => existing,
        single => vec![single],
    };
    while is_or_at(*pos) {
        *pos += 1;
        branches.push(parse_and_expr(tokens, pos)?);
    }
    Ok(RlsPredicate::Or(branches))
}
/// Parses one or more atoms joined by `AND` (case-insensitive).
///
/// A single operand is returned unchanged. With two or more operands the
/// result is `RlsPredicate::And`; if the first operand is itself an `And`,
/// the later operands are appended to its children instead of nesting.
///
/// `pos` is advanced past every token consumed.
fn parse_and_expr(tokens: &[String], pos: &mut usize) -> Result<RlsPredicate, PredicateParseError> {
    let is_and_at = |at: usize| {
        tokens
            .get(at)
            .is_some_and(|tok| tok.eq_ignore_ascii_case("AND"))
    };

    let first = parse_atom(tokens, pos)?;
    if !is_and_at(*pos) {
        return Ok(first);
    }

    // Reuse the children of a leading `And` so the chain stays flat.
    let mut terms = match first {
        RlsPredicate::And(existing) => existing,
        single => vec![single],
    };
    while is_and_at(*pos) {
        *pos += 1;
        terms.push(parse_atom(tokens, pos)?);
    }
    Ok(RlsPredicate::And(terms))
}
fn parse_atom(tokens: &[String], pos: &mut usize) -> Result<RlsPredicate, PredicateParseError> {
if *pos >= tokens.len() {
return Err(PredicateParseError::UnexpectedEnd);
}
if tokens[*pos].to_uppercase() == "NOT" {
*pos += 1;
let inner = parse_atom(tokens, pos)?;
return Ok(RlsPredicate::Not(Box::new(inner)));
}
if tokens[*pos] == "(" {
*pos += 1;
let inner = parse_or_expr(tokens, pos)?;
if *pos >= tokens.len() || tokens[*pos] != ")" {
return Err(PredicateParseError::UnmatchedParen);
}
*pos += 1;
return Ok(inner);
}
let left_value = parse_value_ref(tokens, pos)?;
if *pos >= tokens.len() {
return match left_value {
PredicateValue::AuthRef(_)
| PredicateValue::Field(_)
| PredicateValue::AuthFunc { .. } => Ok(RlsPredicate::Compare {
field: match &left_value {
PredicateValue::Field(f) => f.clone(),
_ => String::new(),
},
op: CompareOp::IsNotNull,
value: left_value,
}),
PredicateValue::Literal(_) => Ok(RlsPredicate::AlwaysTrue),
};
}
let op_token = tokens[*pos].to_uppercase();
match op_token.as_str() {
"CONTAINS" => {
*pos += 1;
let element = parse_value_ref(tokens, pos)?;
Ok(RlsPredicate::Contains {
set: left_value,
element,
})
}
"INTERSECTS" => {
*pos += 1;
let right = parse_value_ref(tokens, pos)?;
Ok(RlsPredicate::Intersects {
left: left_value,
right,
})
}
_ => {
let compare_op = parse_compare_op(tokens, pos)?;
let right_value = parse_value_ref(tokens, pos)?;
match (&left_value, &right_value) {
(PredicateValue::Field(f), _) => Ok(RlsPredicate::Compare {
field: f.clone(),
op: compare_op,
value: right_value,
}),
(_, PredicateValue::Field(f)) => {
let flipped = match compare_op {
CompareOp::Gt => CompareOp::Lt,
CompareOp::Gte => CompareOp::Lte,
CompareOp::Lt => CompareOp::Gt,
CompareOp::Lte => CompareOp::Gte,
other => other, };
Ok(RlsPredicate::Compare {
field: f.clone(),
op: flipped,
value: left_value,
})
}
_ => Ok(RlsPredicate::Compare {
field: String::new(), op: compare_op,
value: right_value,
}),
}
}
}
}
/// Parses the token at `pos` as a single value operand and advances `pos`
/// past it on success.
///
/// Recognised forms, tried in this order:
/// 1. `$auth.<path>` becomes `PredicateValue::AuthRef` holding `<path>`;
/// 2. a quoted token `'...'` becomes a string literal;
/// 3. a token starting with a digit or `-` that parses as `i64`, or else as
///    a finite `f64`, becomes a numeric literal;
/// 4. `true`, `false`, `null` (case-insensitive) become the matching
///    literal;
/// 5. any other token that is not a reserved word becomes
///    `PredicateValue::Field`.
///
/// # Errors
///
/// - [`PredicateParseError::UnexpectedEnd`] when no token is available.
/// - [`PredicateParseError::InvalidLiteral`] when a numeric token parses to
///   a non-finite float (for example `1e999` or `-inf`). Such values used
///   to be turned into a JSON `null` literal silently.
/// - [`PredicateParseError::UnexpectedToken`] when the token is a reserved
///   word (`AND`, `OR`, `NOT`, `CONTAINS`, `INTERSECTS`), or a parenthesis
///   or operator token. The latter used to be accepted as field names.
fn parse_value_ref(
    tokens: &[String],
    pos: &mut usize,
) -> Result<PredicateValue, PredicateParseError> {
    let Some(token) = tokens.get(*pos) else {
        return Err(PredicateParseError::UnexpectedEnd);
    };

    if let Some(field) = token.strip_prefix("$auth.") {
        *pos += 1;
        return Ok(PredicateValue::AuthRef(field.to_string()));
    }

    // Quoted tokens carry exactly one pair of surrounding quotes.
    if let Some(inner) = token
        .strip_prefix('\'')
        .and_then(|rest| rest.strip_suffix('\''))
    {
        *pos += 1;
        return Ok(PredicateValue::Literal(serde_json::Value::String(
            inner.to_string(),
        )));
    }

    // Parenthesis and operator tokens can never be operands.
    if token.starts_with(|c: char| matches!(c, '(' | ')' | '=' | '!' | '<' | '>')) {
        return Err(PredicateParseError::UnexpectedToken {
            token: token.clone(),
            position: *pos,
        });
    }

    if token.starts_with(|c: char| c.is_ascii_digit() || c == '-') {
        if let Ok(n) = token.parse::<i64>() {
            *pos += 1;
            return Ok(PredicateValue::Literal(serde_json::json!(n)));
        }
        if let Ok(f) = token.parse::<f64>() {
            // `f64::from_str` accepts `inf`/`nan` spellings and overflows to infinity;
            // JSON cannot represent those and `json!` would yield `null` instead.
            if !f.is_finite() {
                return Err(PredicateParseError::InvalidLiteral(token.clone()));
            }
            *pos += 1;
            return Ok(PredicateValue::Literal(serde_json::json!(f)));
        }
        // Not a number after all (e.g. `2fa_enabled`): fall through to the word cases.
    }

    let keyword_literal = match token.to_lowercase().as_str() {
        "true" => Some(serde_json::json!(true)),
        "false" => Some(serde_json::json!(false)),
        "null" => Some(serde_json::Value::Null),
        _ => None,
    };
    if let Some(value) = keyword_literal {
        *pos += 1;
        return Ok(PredicateValue::Literal(value));
    }

    let is_reserved = matches!(
        token.to_uppercase().as_str(),
        "AND" | "OR" | "NOT" | "CONTAINS" | "INTERSECTS"
    );
    if is_reserved {
        return Err(PredicateParseError::UnexpectedToken {
            token: token.clone(),
            position: *pos,
        });
    }

    *pos += 1;
    Ok(PredicateValue::Field(token.clone()))
}
/// Reads the token at `pos` as a comparison operator.
///
/// The token is handed to `CompareOp::from_str_sql`; on success `pos` is
/// advanced by one, on failure it is left untouched.
///
/// # Errors
///
/// - [`PredicateParseError::UnexpectedEnd`] when no token is available.
/// - [`PredicateParseError::UnknownOperator`] carrying the token when it is
///   not a recognised operator.
fn parse_compare_op(tokens: &[String], pos: &mut usize) -> Result<CompareOp, PredicateParseError> {
    let token = tokens
        .get(*pos)
        .ok_or(PredicateParseError::UnexpectedEnd)?;

    match CompareOp::from_str_sql(token) {
        Some(op) => {
            *pos += 1;
            Ok(op)
        }
        None => Err(PredicateParseError::UnknownOperator(token.clone())),
    }
}
/// Checks that every `$auth` reference and `$auth` function used in
/// `predicate` names something known.
///
/// The whole predicate tree is walked. For an auth reference only the
/// segment before the first `.` is checked; a reference whose first segment
/// is `metadata` is always accepted. Field and literal values are not
/// checked.
///
/// # Errors
///
/// Returns a message naming the first unknown field or function found,
/// followed by the list of valid names.
pub fn validate_auth_refs(predicate: &RlsPredicate) -> Result<(), String> {
    const KNOWN_FIELDS: &[&str] = &[
        "id",
        "username",
        "email",
        "tenant_id",
        "org_id",
        "org_ids",
        "roles",
        "groups",
        "permissions",
        "status",
        "auth_method",
        "auth_time",
        "session_id",
    ];
    const KNOWN_FUNCS: &[&str] = &[
        "scope_status",
        "scope_expires_at",
        "quota_remaining",
        "quota_pct",
    ];

    /// Validates a single operand.
    fn validate_value(value: &PredicateValue) -> Result<(), String> {
        match value {
            PredicateValue::AuthRef(field) => {
                // Only the leading path segment decides validity.
                let base = field.split('.').next().unwrap_or(field);
                if base == "metadata" || KNOWN_FIELDS.contains(&base) {
                    Ok(())
                } else {
                    Err(format!(
                        "unknown $auth field: '{field}'. Valid fields: {}",
                        KNOWN_FIELDS.join(", ")
                    ))
                }
            }
            PredicateValue::AuthFunc { func, .. } => {
                if KNOWN_FUNCS.contains(&func.as_str()) {
                    Ok(())
                } else {
                    Err(format!(
                        "unknown $auth function: '{func}'. Valid functions: {}",
                        KNOWN_FUNCS.join(", ")
                    ))
                }
            }
            _ => Ok(()),
        }
    }

    /// Validates a predicate node and, recursively, everything beneath it.
    fn validate_node(node: &RlsPredicate) -> Result<(), String> {
        match node {
            RlsPredicate::Compare { value, .. } => validate_value(value),
            RlsPredicate::Contains { set, element } => {
                validate_value(set).and_then(|()| validate_value(element))
            }
            RlsPredicate::Intersects { left, right } => {
                validate_value(left).and_then(|()| validate_value(right))
            }
            RlsPredicate::And(children) | RlsPredicate::Or(children) => {
                children.iter().try_for_each(validate_node)
            }
            RlsPredicate::Not(inner) => validate_node(inner),
            RlsPredicate::AlwaysTrue | RlsPredicate::AlwaysFalse => Ok(()),
        }
    }

    validate_node(predicate)
}