Skip to main content

qail_core/migrate/
policy_parser.rs

1//! Policy SQL expression parser.
2//!
3//! Converts raw SQL policy expressions (from `pg_policies.qual` / `.with_check`)
4//! into typed `Expr` AST nodes that can be rendered back to QAIL format.
5//!
6//! This lives in `qail-core` so all downstream crates (`qail-pg`, CLI, etc.)
7//! can reuse it.
8
9use crate::ast::expr::Expr;
10use crate::ast::{BinaryOp, Value};
11
12/// Parse a raw SQL policy expression from `pg_policies` into a typed `Expr` AST.
13///
14/// Handles the common RLS patterns:
15/// - `col = current_setting('var', true)::type`  (tenant check)
16/// - `(current_setting('var', true))::boolean = true`  (session bool check)
17/// - `expr1 OR expr2` / `expr1 AND expr2` (combinators)
18///
19/// Returns an error for unsupported expressions.
20pub fn parse_policy_expr(sql: &str) -> Result<Expr, String> {
21    let s = sql.trim();
22
23    // Strip outer parens if the entire expression is wrapped
24    let s = strip_outer_parens(s);
25
26    // Try: expr OR expr
27    if let Some(pos) = find_top_level_op(s, " OR ") {
28        let left = parse_policy_expr(&s[..pos])?;
29        let right = parse_policy_expr(&s[pos + 4..])?;
30        return Ok(Expr::Binary {
31            left: Box::new(left),
32            op: BinaryOp::Or,
33            right: Box::new(right),
34            alias: None,
35        });
36    }
37
38    // Try: expr AND expr
39    if let Some(pos) = find_top_level_op(s, " AND ") {
40        let left = parse_policy_expr(&s[..pos])?;
41        let right = parse_policy_expr(&s[pos + 5..])?;
42        return Ok(Expr::Binary {
43            left: Box::new(left),
44            op: BinaryOp::And,
45            right: Box::new(right),
46            alias: None,
47        });
48    }
49
50    // Try: col = current_setting('var', true)::type
51    // or:  (current_setting('var', true))::type = 'true'
52    if let Some(eq_pos) = find_top_level_op(s, " = ") {
53        let lhs = s[..eq_pos].trim();
54        let rhs = s[eq_pos + 3..].trim();
55
56        // Pattern 1: col = current_setting(...)::type
57        if let Some(expr) = try_parse_tenant_check(lhs, rhs) {
58            return Ok(expr);
59        }
60        // Pattern 2: current_setting(...)::type = value (swapped)
61        if let Some(expr) = try_parse_tenant_check(rhs, lhs) {
62            // Swap back so col is on the left
63            return Ok(expr);
64        }
65    }
66
67    Err(format!("unsupported policy expression: {}", s))
68}
69
70/// Try to parse `lhs = rhs` where lhs is a column name and rhs is `current_setting('var', ...)::type`
71fn try_parse_tenant_check(col_side: &str, setting_side: &str) -> Option<Expr> {
72    let (session_var, cast_type) = parse_setting_expr(setting_side)?;
73    let left = parse_policy_lhs(col_side);
74
75    Some(Expr::Binary {
76        left: Box::new(left),
77        op: BinaryOp::Eq,
78        right: Box::new(Expr::Cast {
79            expr: Box::new(Expr::FunctionCall {
80                name: "current_setting".into(),
81                args: vec![Expr::Literal(Value::String(session_var))],
82                alias: None,
83            }),
84            target_type: cast_type,
85            alias: None,
86        }),
87        alias: None,
88    })
89}
90
91fn parse_policy_lhs(col_side: &str) -> Expr {
92    let lhs = strip_outer_parens(col_side).trim();
93    if is_sql_true_literal(lhs) {
94        return Expr::Literal(Value::Bool(true));
95    }
96    if is_sql_false_literal(lhs) {
97        return Expr::Literal(Value::Bool(false));
98    }
99    Expr::Named(lhs.to_string())
100}
101
102fn parse_setting_expr(setting_side: &str) -> Option<(String, String)> {
103    let mut normalized = strip_outer_parens(setting_side).trim().to_string();
104    // Normalize pg_dump-style wrappers like: (NULLIF(...))::uuid -> NULLIF(...)::uuid
105    loop {
106        let candidate = normalized.trim();
107        if !candidate.starts_with('(') {
108            break;
109        }
110        let Some(close_idx) = find_matching_paren(candidate, 0) else {
111            break;
112        };
113        let rest = candidate[close_idx + 1..].trim();
114        if !rest.starts_with("::") {
115            break;
116        }
117        let inner = candidate[1..close_idx].trim();
118        normalized = format!("{inner}{rest}");
119    }
120    let s = normalized.trim();
121
122    // Direct: current_setting('app.current_tenant_id', true)::uuid
123    if let Some((session_var, rest)) = parse_current_setting_call(s) {
124        let cast = parse_cast_suffix(rest).unwrap_or_else(|| "text".to_string());
125        return Some((session_var, cast));
126    }
127
128    // Wrapped: NULLIF(current_setting(...), ''::text)::uuid
129    if let Some((args, rest)) = parse_function_args_and_rest_ci(s, "NULLIF") {
130        let (arg1, _arg2) = split_args2(args)?;
131        let (session_var, mut cast) = parse_setting_expr(arg1.trim())?;
132        if let Some(parsed_cast) = parse_cast_suffix(rest) {
133            cast = parsed_cast;
134        }
135        return Some((session_var, cast));
136    }
137
138    // Wrapped: COALESCE(current_setting(...), 'false'::text)
139    if let Some((args, rest)) = parse_function_args_and_rest_ci(s, "COALESCE") {
140        let (arg1, arg2) = split_args2(args)?;
141        let (session_var, mut cast) = parse_setting_expr(arg1.trim())?;
142        if let Some(parsed_cast) = parse_cast_suffix(rest) {
143            cast = parsed_cast;
144        } else if is_sql_bool_string_literal(arg2.trim()) {
145            // COALESCE(..., 'false'::text) is used for boolean context in pg_dump output
146            cast = "boolean".to_string();
147        }
148        return Some((session_var, cast));
149    }
150
151    None
152}
153
154fn parse_cast_suffix(rest: &str) -> Option<String> {
155    let tail = strip_outer_parens(rest).trim();
156    tail.strip_prefix("::").map(|s| s.trim().to_string())
157}
158
159fn split_args2(args: &str) -> Option<(&str, &str)> {
160    let idx = find_top_level_char(args, ',')?;
161    Some((&args[..idx], &args[idx + 1..]))
162}
163
164fn parse_function_args_and_rest_ci<'a>(s: &'a str, fn_name: &str) -> Option<(&'a str, &'a str)> {
165    let s = s.trim();
166    let prefix = format!("{fn_name}(");
167    if !starts_with_ci(s, &prefix) {
168        return None;
169    }
170    let open_idx = fn_name.len();
171    let close_idx = find_matching_paren(s, open_idx)?;
172    let args = &s[open_idx + 1..close_idx];
173    let rest = &s[close_idx + 1..];
174    Some((args, rest))
175}
176
177fn parse_current_setting_call(expr: &str) -> Option<(String, &str)> {
178    let (args, rest) = parse_function_args_and_rest_ci(expr, "current_setting")?;
179    let session_var = extract_first_string_literal(args)?;
180    Some((session_var, rest))
181}
182
/// ASCII case-insensitive `starts_with`.
///
/// Returns `false` if `s` is shorter than `prefix` or if `prefix.len()` does
/// not fall on a char boundary of `s`.
fn starts_with_ci(s: &str, prefix: &str) -> bool {
    match s.get(..prefix.len()) {
        Some(head) => head.eq_ignore_ascii_case(prefix),
        None => false,
    }
}
187
/// Given the byte offset of a `(` in `s`, return the offset of its matching
/// `)`, ignoring parentheses inside single-quoted SQL string literals.
///
/// Returns `None` if `open_idx` is out of range, is not a `(`, or the paren
/// is never closed.
fn find_matching_paren(s: &str, open_idx: usize) -> Option<usize> {
    let bytes = s.as_bytes();
    if bytes.get(open_idx) != Some(&b'(') {
        return None;
    }
    // Flipping literal state on every quote handles SQL's `''` escape too: the
    // pair toggles twice and leaves us inside the literal, and no parenthesis
    // can appear between the two quotes.
    let mut quoted = false;
    let mut nesting = 0usize;
    for (pos, &byte) in bytes.iter().enumerate().skip(open_idx) {
        match byte {
            b'\'' => quoted = !quoted,
            _ if quoted => {}
            b'(' => nesting += 1,
            b')' => {
                nesting = nesting.saturating_sub(1);
                if nesting == 0 {
                    return Some(pos);
                }
            }
            _ => {}
        }
    }
    None
}
225
/// Return the byte offset of the first `needle` that lies outside parentheses
/// and outside single-quoted SQL string literals (`''` is an escaped quote).
///
/// `(`, `)` and `'` are structural and are never reported as matches. The
/// needle is compared by its full UTF-8 encoding rather than byte-as-char.
/// This means a non-ASCII needle cannot falsely match a UTF-8 continuation
/// byte, and the returned offset is always a char boundary of `s` (safe to
/// slice at).
fn find_top_level_char(s: &str, needle: char) -> Option<usize> {
    let mut utf8 = [0u8; 4];
    let needle_bytes = needle.encode_utf8(&mut utf8).as_bytes();
    let bytes = s.as_bytes();
    let mut depth = 0i32;
    let mut in_string = false;
    let mut i = 0usize;
    while i < bytes.len() {
        let b = bytes[i];
        if in_string {
            if b == b'\'' {
                // SQL escaped quote: ''
                if bytes.get(i + 1) == Some(&b'\'') {
                    i += 2;
                    continue;
                }
                in_string = false;
            }
            i += 1;
            continue;
        }
        match b {
            b'\'' => in_string = true,
            b'(' => depth += 1,
            b')' => depth -= 1,
            _ => {
                if depth == 0 && bytes[i..].starts_with(needle_bytes) {
                    return Some(i);
                }
            }
        }
        i += 1;
    }
    None
}
258
/// Decode the single-quoted SQL string literal at the start of `s` (after
/// trimming) and return its contents, with `''` unescaped to `'`.
///
/// Anything after the closing quote (e.g. `::text, true`) is ignored. Returns
/// `None` if `s` does not start with `'` or the literal is unterminated.
///
/// The literal is copied as `&str` slices rather than byte by byte, so
/// multi-byte UTF-8 characters (e.g. in a session variable name) are kept
/// intact.
fn extract_first_string_literal(s: &str) -> Option<String> {
    let mut rest = s.trim().strip_prefix('\'')?;
    let mut out = String::new();
    loop {
        let quote = rest.find('\'')?;
        out.push_str(&rest[..quote]);
        rest = &rest[quote + 1..];
        match rest.strip_prefix('\'') {
            // `''` inside a literal is an escaped single quote.
            Some(after_escape) => {
                out.push('\'');
                rest = after_escape;
            }
            None => return Some(out),
        }
    }
}
282
/// Return `true` when `s` is a SQL spelling of boolean true.
///
/// Accepted spellings are bare `true`, or the string `'true'` either uncast or
/// cast to `text` / `varchar`. Matching is ASCII case-insensitive and ignores
/// surrounding whitespace. `'true'::character varying` is also accepted,
/// because that is how PostgreSQL deparses a `varchar` cast in catalog output.
fn is_sql_true_literal(s: &str) -> bool {
    matches!(
        s.trim().to_ascii_lowercase().as_str(),
        "true" | "'true'" | "'true'::text" | "'true'::varchar" | "'true'::character varying"
    )
}
289
/// Return `true` when `s` is a SQL spelling of boolean false.
///
/// Accepted spellings are bare `false`, or the string `'false'` either uncast
/// or cast to `text` / `varchar`. Matching is ASCII case-insensitive and
/// ignores surrounding whitespace. `'false'::character varying` is also
/// accepted, because that is how PostgreSQL deparses a `varchar` cast in
/// catalog output.
fn is_sql_false_literal(s: &str) -> bool {
    matches!(
        s.trim().to_ascii_lowercase().as_str(),
        "false" | "'false'" | "'false'::text" | "'false'::varchar" | "'false'::character varying"
    )
}
296
297fn is_sql_bool_string_literal(s: &str) -> bool {
298    is_sql_true_literal(s) || is_sql_false_literal(s)
299}
300
/// Strip balanced outer parentheses from an expression, repeatedly, trimming
/// whitespace at each step.
///
/// A leading `(` is removed only if its matching `)` is the final character.
/// For example, `((foo))` → `foo`, while `(a) AND (b)` is returned unchanged.
/// Parentheses inside single-quoted SQL string literals (`''` escapes a quote)
/// are ignored when matching. Without this, `(a = '(') AND (b = ')')` would be
/// wrongly stripped and `(x = ')(')` would wrongly be left intact.
pub fn strip_outer_parens(s: &str) -> &str {
    let mut s = s.trim();
    while s.starts_with('(') && s.ends_with(')') {
        let bytes = s.as_bytes();
        let last = bytes.len() - 1;
        let mut depth = 0i32;
        let mut in_string = false;
        let mut wraps_whole = false;
        let mut i = 0usize;
        while i < bytes.len() {
            let b = bytes[i];
            if in_string {
                if b == b'\'' {
                    // SQL escaped quote: ''
                    if bytes.get(i + 1) == Some(&b'\'') {
                        i += 2;
                        continue;
                    }
                    in_string = false;
                }
                i += 1;
                continue;
            }
            match b {
                b'\'' => in_string = true,
                b'(' => depth += 1,
                b')' => {
                    depth -= 1;
                    if depth == 0 {
                        // The opening paren's partner: a wrapper only if it is the last char.
                        wraps_whole = i == last;
                        break;
                    }
                }
                _ => {}
            }
            i += 1;
        }
        if !wraps_whole {
            break;
        }
        s = s[1..last].trim();
    }
    s
}
327
/// Find the first top-level occurrence of `op` in `s`, i.e. one that is not
/// inside parentheses and not inside a single-quoted SQL string literal (`''`
/// escapes a quote inside a literal).
///
/// Returns the byte offset where `op` starts. Matching is exact (case-sensitive)
/// on bytes. Skipping literals prevents false splits such as
/// `status = 'a OR b'` being cut at the ` OR ` inside the string.
pub fn find_top_level_op(s: &str, op: &str) -> Option<usize> {
    let bytes = s.as_bytes();
    let op_bytes = op.as_bytes();
    if bytes.len() < op_bytes.len() {
        return None;
    }
    let mut depth = 0i32;
    let mut in_string = false;
    let mut i = 0usize;
    while i + op_bytes.len() <= bytes.len() {
        let b = bytes[i];
        if in_string {
            if b == b'\'' {
                // SQL escaped quote: ''
                if bytes.get(i + 1) == Some(&b'\'') {
                    i += 2;
                    continue;
                }
                in_string = false;
            }
            i += 1;
            continue;
        }
        match b {
            b'\'' => {
                in_string = true;
                i += 1;
                continue;
            }
            b'(' => depth += 1,
            b')' => depth -= 1,
            _ => {}
        }
        // Depth is updated before the comparison, matching the original scan order.
        if depth == 0 && bytes[i..i + op_bytes.len()] == *op_bytes {
            return Some(i);
        }
        i += 1;
    }
    None
}
348
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that `expr` is an `Expr::Binary` equality and return its operands.
    fn eq_operands(expr: &Expr) -> (&Expr, &Expr) {
        match expr {
            Expr::Binary {
                left, op, right, ..
            } => {
                assert!(matches!(op, BinaryOp::Eq));
                (left.as_ref(), right.as_ref())
            }
            _ => panic!("Expected Binary, got {:?}", expr),
        }
    }

    /// True when `expr` is a cast whose target type is exactly `ty`.
    fn is_cast_to(expr: &Expr, ty: &str) -> bool {
        matches!(expr, Expr::Cast { target_type, .. } if target_type == ty)
    }

    #[test]
    fn test_simple_tenant_check() {
        let parsed = parse_policy_expr("id = current_setting('app.current_tenant_id', true)::uuid")
            .expect("expected tenant check parse");
        let (lhs, rhs) = eq_operands(&parsed);
        assert!(matches!(lhs, Expr::Named(n) if n == "id"));
        assert!(is_cast_to(rhs, "uuid"));
    }

    #[test]
    fn test_or_combinator() {
        let parsed = parse_policy_expr(
            "id = current_setting('app.op', true)::uuid OR current_setting('app.admin', true)::boolean = true",
        )
        .expect("expected OR parse");
        let is_or = matches!(
            &parsed,
            Expr::Binary {
                op: BinaryOp::Or,
                ..
            }
        );
        assert!(is_or);
    }

    #[test]
    fn test_unsupported_expr_returns_error() {
        let outcome = parse_policy_expr("status = 'cancelled'::text");
        assert!(outcome.is_err());
    }

    #[test]
    fn test_coalesce_current_setting_boolean_eq_true() {
        let parsed = parse_policy_expr(
            "COALESCE(current_setting('app.is_super_admin'::text, true), 'false'::text) = 'true'::text",
        )
        .expect("expected COALESCE(current_setting(...)) parse");
        let (lhs, rhs) = eq_operands(&parsed);
        assert!(matches!(lhs, Expr::Literal(Value::Bool(true))));
        assert!(is_cast_to(rhs, "boolean"));
    }

    #[test]
    fn test_nullif_current_setting_cast_uuid() {
        let parsed = parse_policy_expr(
            "tenant_id = (NULLIF(current_setting('app.current_tenant_id'::text, true), ''::text))::uuid",
        )
        .expect("expected NULLIF(current_setting(...)) parse");
        let (lhs, rhs) = eq_operands(&parsed);
        assert!(matches!(lhs, Expr::Named(n) if n == "tenant_id"));
        assert!(is_cast_to(rhs, "uuid"));
    }

    #[test]
    fn test_strip_outer_parens() {
        let cases = [
            ("(foo)", "foo"),
            ("((foo))", "foo"),
            ("(a) AND (b)", "(a) AND (b)"),
        ];
        for (input, expected) in cases {
            assert_eq!(strip_outer_parens(input), expected);
        }
    }
}
438}