/// Positive outcome of [`classify_sql_query`].
///
/// Queries the classifier does not recognise are reported as `None` by the
/// classifier itself rather than through a variant of this enum.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum SqlAuthClassification {
    /// The query matched one of the recognised user-scoping patterns.
    Authorized,
}
/// Classifies `sql` as scoped to the calling user, or not.
///
/// The statement is normalised (ASCII-lowercased, whitespace collapsed) and
/// must start with `select`. It is then reported as
/// [`SqlAuthClassification::Authorized`] when either
///
/// * it joins one of `acl_tables` and its `WHERE` clause binds a
///   `user_id`/`userid` column to a placeholder, or
/// * it has no `JOIN` and its `WHERE` clause binds such a column directly.
///
/// Everything else yields `None`. In particular, compound queries
/// (`UNION`/`INTERSECT`/`EXCEPT`) and input containing more than one statement
/// are rejected up front. The pattern checks only inspect the first `WHERE`
/// clause, so another `SELECT` branch or statement could read rows that are
/// not restricted to the bound user.
pub fn classify_sql_query(sql: &str, acl_tables: &[String]) -> Option<SqlAuthClassification> {
    let normalized = normalize_sql(sql);
    if !normalized.trim_start().starts_with("select")
        || is_compound_or_multi_statement(&normalized)
    {
        return None;
    }
    let authorized = matches_join_through_acl(&normalized, acl_tables)
        || matches_direct_user_id_predicate(&normalized);
    authorized.then_some(SqlAuthClassification::Authorized)
}

/// Returns `true` when `normalized` (output of `normalize_sql`) combines
/// several `SELECT`s with a set operator, or contains a `;` other than a
/// single trailing statement terminator.
///
/// This is deliberately conservative: a `;` or operator keyword inside a
/// string literal also causes rejection.
fn is_compound_or_multi_statement(normalized: &str) -> bool {
    let is_ident_char = |c: char| c.is_ascii_alphanumeric() || c == '_';
    // Whole-word match, so identifiers such as `reunion_id` or `except_flag`
    // do not trigger a rejection, while `union(` or `)union` still do.
    let has_set_operator = ["union", "intersect", "except"].iter().any(|op| {
        normalized.match_indices(*op).any(|(idx, _)| {
            let before = normalized[..idx].chars().next_back();
            let after = normalized[idx + op.len()..].chars().next();
            !before.is_some_and(is_ident_char) && !after.is_some_and(is_ident_char)
        })
    });
    if has_set_operator {
        return true;
    }
    // One trailing `;` just terminates the statement; any other `;`
    // separates statements that the checks above never look at.
    let body = normalized.trim_end();
    let body = body.strip_suffix(';').unwrap_or(body);
    body.contains(';')
}
/// Returns `true` when `sql` (normalised by `normalize_sql`) joins one of
/// `acl_tables` before its first ` where `, and that `WHERE` clause binds a
/// `user_id`/`userid` column to a placeholder.
///
/// A table counts as joined when ` join <table>` appears in the text before
/// ` where `, either followed by a space or at the very end of that text.
/// Table names are compared after ASCII-lowercasing. The `INNER`/`LEFT`/`RIGHT
/// JOIN` forms are covered because each of them contains ` join <table> `.
///
/// NOTE(review): the `user_id` bind is not tied to the ACL table's alias, so a
/// predicate on some other joined table also satisfies this check.
fn matches_join_through_acl(sql: &str, acl_tables: &[String]) -> bool {
    const WHERE_KW: &str = " where ";
    let Some(pos) = sql.find(WHERE_KW) else {
        return false;
    };
    let (before_where, after_where) = (&sql[..pos], &sql[pos + WHERE_KW.len()..]);
    let joins_acl_table = acl_tables
        .iter()
        .map(|table| table.to_ascii_lowercase())
        .any(|table| {
            let join_probe = format!(" join {}", table);
            before_where.ends_with(&join_probe)
                || before_where.contains(&format!("{} ", join_probe))
        });
    joins_acl_table && where_clause_contains_user_id_bind(after_where)
}
/// Returns `true` when `sql` (normalised by `normalize_sql`) has no ` join `
/// before its first ` where `, and that `WHERE` clause binds a
/// `user_id`/`userid` column to a placeholder.
///
/// NOTE(review): only the literal ` join ` token is checked, so comma-style
/// joins (`FROM a, b`) are not excluded here. Confirm this is acceptable.
fn matches_direct_user_id_predicate(sql: &str) -> bool {
    const WHERE_KW: &str = " where ";
    match sql.find(WHERE_KW) {
        Some(pos) if !sql[..pos].contains(" join ") => {
            where_clause_contains_user_id_bind(&sql[pos + WHERE_KW.len()..])
        }
        _ => false,
    }
}
/// Returns `true` when the predicate part of `where_clause` compares a
/// `user_id` or `userid` column for equality against a bind placeholder
/// (`?`, `?N`, `$N` or `:name`).
///
/// `where_clause` is the normalised text following ` where `. Anything from a
/// trailing ` order by `, ` limit `, ` group by ` or ` having ` onwards is
/// ignored. The column may be qualified (`gm.user_id`) or parenthesised, but
/// it must not be the tail of a longer name (`posted_user_id` does not count).
///
/// Returns `false` whenever the predicate contains an `OR` keyword. A
/// disjunction such as `user_id = ?1 OR is_public = 1` also matches rows that
/// do not belong to the bound user, so the `user_id` test no longer scopes the
/// query. This errs on the side of rejecting harmless disjunctions too.
///
/// NOTE(review): negation (`NOT user_id = ?1`) is not detected.
fn where_clause_contains_user_id_bind(where_clause: &str) -> bool {
    // Whole-word search for the `or` keyword, so identifiers such as
    // `order_id` or `color` are not mistaken for it.
    fn has_or_keyword(text: &str) -> bool {
        let is_ident_char = |c: char| c.is_ascii_alphanumeric() || c == '_';
        text.match_indices("or").any(|(idx, kw)| {
            let before = text[..idx].chars().next_back();
            let after = text[idx + kw.len()..].chars().next();
            !before.is_some_and(is_ident_char) && !after.is_some_and(is_ident_char)
        })
    }

    let where_only = strip_trailing_clauses(where_clause);
    if has_or_keyword(where_only) {
        return false;
    }
    for needle in ["user_id", "userid"] {
        for (idx, _) in where_only.match_indices(needle) {
            // Skip matches that are the tail of a longer identifier.
            if !is_column_boundary_left(where_only[..idx].chars().next_back()) {
                continue;
            }
            // Only an equality comparison (`user_id = <bind>`) scopes the query.
            let rest = where_only[idx + needle.len()..].trim_start();
            let Some(after_eq) = rest.strip_prefix('=') else {
                continue;
            };
            if looks_like_bind_param(after_eq.trim_start()) {
                return true;
            }
        }
    }
    false
}
/// Decides whether a candidate column name may start right after `ch`, the
/// character immediately preceding it.
///
/// Returns `true` when there is no preceding character, or when it is one of
/// space, tab, newline, carriage return, `(`, `.` or `,`. Any other character
/// (letters, digits, `_`, operators, ...) returns `false`.
fn is_column_boundary_left(ch: Option<char>) -> bool {
    const SEPARATORS: [char; 7] = [' ', '\t', '\n', '\r', '(', '.', ','];
    ch.map_or(true, |prev| SEPARATORS.contains(&prev))
}
/// Returns `true` when `after_eq` begins with a bind placeholder.
///
/// Accepted forms are:
/// * `?`, optionally followed by anything (e.g. `?1`);
/// * `$` followed by an ASCII digit;
/// * `:` followed by an ASCII letter or `_`.
fn looks_like_bind_param(after_eq: &str) -> bool {
    let mut bytes = after_eq.bytes();
    let first = bytes.next();
    let second = bytes.next();
    match (first, second) {
        (Some(b'?'), _) => true,
        (Some(b'$'), Some(next)) => next.is_ascii_digit(),
        (Some(b':'), Some(next)) => next.is_ascii_alphabetic() || next == b'_',
        _ => false,
    }
}
/// Truncates `where_clause` at the earliest occurrence of ` order by `,
/// ` limit `, ` group by ` or ` having `, leaving only the predicate text.
///
/// The input is returned unchanged when none of those markers occur.
fn strip_trailing_clauses(where_clause: &str) -> &str {
    let cut = [" order by ", " limit ", " group by ", " having "]
        .iter()
        .filter_map(|marker| where_clause.find(*marker))
        .min()
        .unwrap_or(where_clause.len());
    &where_clause[..cut]
}
/// Canonicalises `sql` for substring matching.
///
/// ASCII letters are lowercased (other characters are kept as-is), and every
/// run of whitespace collapses to a single space. The result is padded with
/// exactly one space at each end, so keyword probes such as `" where "` also
/// match at the edges. Blank input yields `" "`.
fn normalize_sql(sql: &str) -> String {
    let mut out = String::with_capacity(sql.len());
    for word in sql.split_whitespace() {
        out.push(' ');
        out.extend(word.chars().map(|c| c.to_ascii_lowercase()));
    }
    out.push(' ');
    out
}
#[cfg(test)]
mod tests {
    use super::{classify_sql_query, SqlAuthClassification};

    /// ACL table list shared by the tests.
    fn acl() -> Vec<String> {
        [
            "group_members",
            "org_memberships",
            "workspace_members",
            "tenant_members",
            "members",
            "share_grants",
        ]
        .iter()
        .map(|name| name.to_string())
        .collect()
    }

    /// Asserts that `sql` is classified as authorized against `acl_tables`.
    fn assert_authorized(sql: &str, acl_tables: &[String]) {
        assert_eq!(
            classify_sql_query(sql, acl_tables),
            Some(SqlAuthClassification::Authorized),
            "expected authorized: {sql}"
        );
    }

    /// Asserts that `sql` receives no classification against `acl_tables`.
    fn assert_unclassified(sql: &str, acl_tables: &[String]) {
        assert_eq!(
            classify_sql_query(sql, acl_tables),
            None,
            "expected no classification: {sql}"
        );
    }

    #[test]
    fn join_through_group_members_with_user_bind_is_authorized() {
        assert_authorized(
            "SELECT d.id, d.group_id, d.title \
             FROM docs d \
             JOIN group_members gm ON gm.group_id = d.group_id \
             WHERE gm.user_id = ?1 \
             ORDER BY d.updated_at DESC",
            &acl(),
        );
    }

    #[test]
    fn join_through_workspace_members_with_postgres_bind() {
        assert_authorized(
            "SELECT t.* \
             FROM tickets t \
             INNER JOIN workspace_members wm ON wm.workspace_id = t.workspace_id \
             WHERE wm.user_id = $1",
            &acl(),
        );
    }

    #[test]
    fn direct_user_id_predicate_is_authorized() {
        assert_authorized("SELECT id, name FROM peers WHERE user_id = ?1", &acl());
    }

    #[test]
    fn direct_id_and_user_id_predicate_is_authorized() {
        assert_authorized(
            "SELECT title FROM docs WHERE id = ?1 AND user_id = ?2",
            &acl(),
        );
    }

    #[test]
    fn named_bind_is_authorized() {
        assert_authorized("SELECT * FROM peers WHERE user_id = :uid", &acl());
    }

    #[test]
    fn join_against_non_acl_table_is_not_authorized() {
        assert_unclassified(
            "SELECT d.* FROM docs d \
             JOIN audit_log al ON al.doc_id = d.id \
             WHERE al.user_id = ?1",
            &acl(),
        );
    }

    #[test]
    fn select_without_user_id_predicate_is_not_authorized() {
        assert_unclassified("SELECT * FROM docs WHERE id = ?1", &acl());
    }

    #[test]
    fn non_select_query_is_not_authorized() {
        assert_unclassified("DELETE FROM docs WHERE user_id = ?1", &acl());
    }

    #[test]
    fn similar_column_names_do_not_trip_user_id_match() {
        assert_unclassified("SELECT * FROM posts WHERE posted_user_id = ?1", &acl());
    }

    #[test]
    fn order_by_after_user_id_is_handled() {
        assert_authorized(
            "SELECT * FROM peers WHERE user_id = ?1 ORDER BY created_at DESC LIMIT 50",
            &acl(),
        );
    }

    #[test]
    fn empty_acl_list_disables_join_pattern_but_keeps_direct() {
        let no_acl: Vec<String> = Vec::new();
        assert_unclassified(
            "SELECT * FROM docs d \
             JOIN group_members gm ON gm.group_id = d.group_id \
             WHERE gm.user_id = ?1",
            &no_acl,
        );
        assert_authorized("SELECT * FROM peers WHERE user_id = ?1", &no_acl);
    }
}