1use crate::ast::expr::Expr;
10use crate::ast::{BinaryOp, Value};
11
12pub fn parse_policy_expr(sql: &str) -> Result<Expr, String> {
21 let s = sql.trim();
22
23 let s = strip_outer_parens(s);
25
26 if let Some(pos) = find_top_level_op(s, " OR ") {
28 let left = parse_policy_expr(&s[..pos])?;
29 let right = parse_policy_expr(&s[pos + 4..])?;
30 return Ok(Expr::Binary {
31 left: Box::new(left),
32 op: BinaryOp::Or,
33 right: Box::new(right),
34 alias: None,
35 });
36 }
37
38 if let Some(pos) = find_top_level_op(s, " AND ") {
40 let left = parse_policy_expr(&s[..pos])?;
41 let right = parse_policy_expr(&s[pos + 5..])?;
42 return Ok(Expr::Binary {
43 left: Box::new(left),
44 op: BinaryOp::And,
45 right: Box::new(right),
46 alias: None,
47 });
48 }
49
50 if let Some(eq_pos) = find_top_level_op(s, " = ") {
53 let lhs = s[..eq_pos].trim();
54 let rhs = s[eq_pos + 3..].trim();
55
56 if let Some(expr) = try_parse_tenant_check(lhs, rhs) {
58 return Ok(expr);
59 }
60 if let Some(expr) = try_parse_tenant_check(rhs, lhs) {
62 return Ok(expr);
64 }
65 }
66
67 Err(format!("unsupported policy expression: {}", s))
68}
69
70fn try_parse_tenant_check(col_side: &str, setting_side: &str) -> Option<Expr> {
72 let (session_var, cast_type) = parse_setting_expr(setting_side)?;
73 let left = parse_policy_lhs(col_side);
74
75 Some(Expr::Binary {
76 left: Box::new(left),
77 op: BinaryOp::Eq,
78 right: Box::new(Expr::Cast {
79 expr: Box::new(Expr::FunctionCall {
80 name: "current_setting".into(),
81 args: vec![Expr::Literal(Value::String(session_var))],
82 alias: None,
83 }),
84 target_type: cast_type,
85 alias: None,
86 }),
87 alias: None,
88 })
89}
90
91fn parse_policy_lhs(col_side: &str) -> Expr {
92 let lhs = strip_outer_parens(col_side).trim();
93 if is_sql_true_literal(lhs) {
94 return Expr::Literal(Value::Bool(true));
95 }
96 if is_sql_false_literal(lhs) {
97 return Expr::Literal(Value::Bool(false));
98 }
99 Expr::Named(lhs.to_string())
100}
101
102fn parse_setting_expr(setting_side: &str) -> Option<(String, String)> {
103 let mut normalized = strip_outer_parens(setting_side).trim().to_string();
104 loop {
106 let candidate = normalized.trim();
107 if !candidate.starts_with('(') {
108 break;
109 }
110 let Some(close_idx) = find_matching_paren(candidate, 0) else {
111 break;
112 };
113 let rest = candidate[close_idx + 1..].trim();
114 if !rest.starts_with("::") {
115 break;
116 }
117 let inner = candidate[1..close_idx].trim();
118 normalized = format!("{inner}{rest}");
119 }
120 let s = normalized.trim();
121
122 if let Some((session_var, rest)) = parse_current_setting_call(s) {
124 let cast = parse_cast_suffix(rest).unwrap_or_else(|| "text".to_string());
125 return Some((session_var, cast));
126 }
127
128 if let Some((args, rest)) = parse_function_args_and_rest_ci(s, "NULLIF") {
130 let (arg1, _arg2) = split_args2(args)?;
131 let (session_var, mut cast) = parse_setting_expr(arg1.trim())?;
132 if let Some(parsed_cast) = parse_cast_suffix(rest) {
133 cast = parsed_cast;
134 }
135 return Some((session_var, cast));
136 }
137
138 if let Some((args, rest)) = parse_function_args_and_rest_ci(s, "COALESCE") {
140 let (arg1, arg2) = split_args2(args)?;
141 let (session_var, mut cast) = parse_setting_expr(arg1.trim())?;
142 if let Some(parsed_cast) = parse_cast_suffix(rest) {
143 cast = parsed_cast;
144 } else if is_sql_bool_string_literal(arg2.trim()) {
145 cast = "boolean".to_string();
147 }
148 return Some((session_var, cast));
149 }
150
151 None
152}
153
154fn parse_cast_suffix(rest: &str) -> Option<String> {
155 let tail = strip_outer_parens(rest).trim();
156 tail.strip_prefix("::").map(|s| s.trim().to_string())
157}
158
159fn split_args2(args: &str) -> Option<(&str, &str)> {
160 let idx = find_top_level_char(args, ',')?;
161 Some((&args[..idx], &args[idx + 1..]))
162}
163
164fn parse_function_args_and_rest_ci<'a>(s: &'a str, fn_name: &str) -> Option<(&'a str, &'a str)> {
165 let s = s.trim();
166 let prefix = format!("{fn_name}(");
167 if !starts_with_ci(s, &prefix) {
168 return None;
169 }
170 let open_idx = fn_name.len();
171 let close_idx = find_matching_paren(s, open_idx)?;
172 let args = &s[open_idx + 1..close_idx];
173 let rest = &s[close_idx + 1..];
174 Some((args, rest))
175}
176
177fn parse_current_setting_call(expr: &str) -> Option<(String, &str)> {
178 let (args, rest) = parse_function_args_and_rest_ci(expr, "current_setting")?;
179 let session_var = extract_first_string_literal(args)?;
180 Some((session_var, rest))
181}
182
/// Returns `true` when `s` begins with `prefix`, ignoring ASCII case.
///
/// Uses a checked slice, so a `prefix` length that exceeds `s` or falls
/// inside a multi-byte character yields `false` rather than panicking.
fn starts_with_ci(s: &str, prefix: &str) -> bool {
    match s.get(..prefix.len()) {
        Some(head) => head.eq_ignore_ascii_case(prefix),
        None => false,
    }
}
187
/// Finds the `)` that closes the `(` located at byte offset `open_idx`.
///
/// Parentheses inside single-quoted SQL literals are ignored; a doubled quote
/// (`''`) inside a literal is treated as an escaped quote. Returns `None` if
/// `open_idx` is out of range, does not point at `(`, or the parenthesis is
/// never closed.
fn find_matching_paren(s: &str, open_idx: usize) -> Option<usize> {
    let bytes = s.as_bytes();
    if bytes.get(open_idx).copied()? != b'(' {
        return None;
    }

    let mut nesting = 0usize;
    let mut quoted = false;
    let mut scan = bytes.iter().copied().enumerate().skip(open_idx).peekable();

    while let Some((pos, byte)) = scan.next() {
        if quoted {
            if byte == b'\'' {
                if matches!(scan.peek(), Some(&(_, b'\''))) {
                    // Escaped quote: consume its partner and stay in the literal.
                    scan.next();
                } else {
                    quoted = false;
                }
            }
            continue;
        }

        match byte {
            b'\'' => quoted = true,
            b'(' => nesting += 1,
            b')' => {
                nesting = nesting.saturating_sub(1);
                if nesting == 0 {
                    return Some(pos);
                }
            }
            _ => {}
        }
    }
    None
}
225
/// Returns the byte offset of the first `needle` that is outside both
/// parentheses and single-quoted SQL literals.
///
/// A doubled quote (`''`) inside a literal is treated as an escaped quote.
/// The quote and parenthesis characters themselves are consumed as structure
/// and are never reported as a match.
///
/// The scan works on characters rather than raw bytes, so a non-ASCII
/// `needle` is compared against whole characters and the returned offset is
/// always a valid `str` slicing boundary.
fn find_top_level_char(s: &str, needle: char) -> Option<usize> {
    let mut depth = 0i32;
    let mut in_string = false;
    let mut chars = s.char_indices().peekable();

    while let Some((idx, ch)) = chars.next() {
        if in_string {
            if ch == '\'' {
                if matches!(chars.peek(), Some(&(_, '\''))) {
                    // Escaped quote: skip its partner and stay in the literal.
                    chars.next();
                } else {
                    in_string = false;
                }
            }
            continue;
        }

        match ch {
            '\'' => in_string = true,
            '(' => depth += 1,
            ')' => depth -= 1,
            _ if depth == 0 && ch == needle => return Some(idx),
            _ => {}
        }
    }
    None
}
258
/// Extracts the contents of the single-quoted SQL literal that starts `s`.
///
/// Leading whitespace is ignored. A doubled quote (`''`) inside the literal
/// is unescaped to a single `'`. Anything after the closing quote is ignored.
/// Returns `None` when `s` does not start with a quote or the literal is
/// never terminated.
///
/// The literal is read character by character, so non-ASCII content is
/// preserved unchanged.
fn extract_first_string_literal(s: &str) -> Option<String> {
    let body = s.trim().strip_prefix('\'')?;
    let mut out = String::new();
    let mut chars = body.chars().peekable();

    while let Some(ch) = chars.next() {
        if ch != '\'' {
            out.push(ch);
            continue;
        }
        if chars.peek() == Some(&'\'') {
            // Escaped quote inside the literal.
            chars.next();
            out.push('\'');
        } else {
            return Some(out);
        }
    }
    None
}
282
/// Returns `true` when `s` (trimmed, ASCII case-insensitive) is one of the
/// accepted spellings of SQL `true`: bare, quoted, or quoted with a
/// `::text` / `::varchar` cast.
fn is_sql_true_literal(s: &str) -> bool {
    const SPELLINGS: [&str; 4] = ["true", "'true'", "'true'::text", "'true'::varchar"];
    let candidate = s.trim();
    SPELLINGS
        .iter()
        .any(|spelling| candidate.eq_ignore_ascii_case(spelling))
}
289
/// Returns `true` when `s` (trimmed, ASCII case-insensitive) is one of the
/// accepted spellings of SQL `false`: bare, quoted, or quoted with a
/// `::text` / `::varchar` cast.
fn is_sql_false_literal(s: &str) -> bool {
    const SPELLINGS: [&str; 4] = ["false", "'false'", "'false'::text", "'false'::varchar"];
    let candidate = s.trim();
    SPELLINGS
        .iter()
        .any(|spelling| candidate.eq_ignore_ascii_case(spelling))
}
296
297fn is_sql_bool_string_literal(s: &str) -> bool {
298 is_sql_true_literal(s) || is_sql_false_literal(s)
299}
300
/// Removes redundant parentheses that wrap the whole expression.
///
/// The input is trimmed; then, as long as the first `(` is closed by the very
/// last character, that pair is removed and the remainder is trimmed again.
/// Inputs such as `(a) AND (b)`, where the first parenthesis closes early, and
/// unbalanced inputs are returned trimmed but otherwise unchanged.
pub fn strip_outer_parens(s: &str) -> &str {
    let mut current = s.trim();

    while current.starts_with('(') && current.ends_with(')') {
        let last = current.len() - 1;

        // Locate the `)` that balances the opening parenthesis at offset 0.
        let mut depth = 0i32;
        let mut closes_at = None;
        for (idx, byte) in current.bytes().enumerate() {
            if byte == b'(' {
                depth += 1;
            } else if byte == b')' {
                depth -= 1;
                if depth == 0 {
                    closes_at = Some(idx);
                    break;
                }
            }
        }

        // Only strip when the first `(` and the final `)` belong together.
        if closes_at != Some(last) {
            break;
        }
        current = current[1..last].trim();
    }

    current
}
327
/// Returns the byte offset of the first occurrence of `op` that is not nested
/// inside parentheses or inside a single-quoted SQL literal.
///
/// The match is exact and case-sensitive. Text inside `'...'` literals is
/// skipped entirely, so neither operators nor parentheses that appear in a
/// literal affect the result; a doubled quote (`''`) inside a literal is
/// treated as an escaped quote. Returns `None` when `op` does not occur at the
/// top level or is longer than `s`.
pub fn find_top_level_op(s: &str, op: &str) -> Option<usize> {
    let bytes = s.as_bytes();
    let op_bytes = op.as_bytes();
    // Last offset at which `op` could still fit; `None` when `s` is too short.
    let last_start = bytes.len().checked_sub(op_bytes.len())?;

    let mut depth = 0i32;
    let mut in_string = false;
    let mut i = 0usize;

    // The `i < bytes.len()` guard only matters for an empty `op`.
    while i <= last_start && i < bytes.len() {
        let b = bytes[i];

        if in_string {
            if b == b'\'' {
                if bytes.get(i + 1) == Some(&b'\'') {
                    // Escaped quote: skip both and stay in the literal.
                    i += 2;
                    continue;
                }
                in_string = false;
            }
            i += 1;
            continue;
        }

        match b {
            b'\'' => in_string = true,
            b'(' => depth += 1,
            b')' => depth -= 1,
            _ => {}
        }

        if depth == 0 && bytes[i..].starts_with(op_bytes) {
            return Some(i);
        }
        i += 1;
    }
    None
}
348
#[cfg(test)]
mod tests {
    use super::*;

    /// A plain `column = current_setting(...)::uuid` check keeps the column
    /// name on the left and the cast type on the right.
    #[test]
    fn test_simple_tenant_check() {
        let parsed = parse_policy_expr("id = current_setting('app.current_tenant_id', true)::uuid")
            .expect("expected tenant check parse");
        let Expr::Binary {
            left, op, right, ..
        } = &parsed
        else {
            panic!("Expected Binary, got {:?}", parsed);
        };
        assert!(matches!(op, BinaryOp::Eq));
        assert!(matches!(left.as_ref(), Expr::Named(n) if n == "id"));
        assert!(matches!(right.as_ref(), Expr::Cast { target_type, .. } if target_type == "uuid"));
    }

    /// Two checks joined by a top-level `OR` produce an `Or` node at the root.
    #[test]
    fn test_or_combinator() {
        let parsed = parse_policy_expr(
            "id = current_setting('app.op', true)::uuid OR current_setting('app.admin', true)::boolean = true",
        )
        .expect("expected OR parse");
        let is_or = matches!(
            &parsed,
            Expr::Binary {
                op: BinaryOp::Or,
                ..
            }
        );
        assert!(is_or);
    }

    /// Comparisons that do not involve `current_setting` are rejected.
    #[test]
    fn test_unsupported_expr_returns_error() {
        assert!(parse_policy_expr("status = 'cancelled'::text").is_err());
    }

    /// `COALESCE(current_setting(...), 'false') = 'true'` is read as a boolean
    /// flag compared against a boolean literal.
    #[test]
    fn test_coalesce_current_setting_boolean_eq_true() {
        let parsed = parse_policy_expr(
            "COALESCE(current_setting('app.is_super_admin'::text, true), 'false'::text) = 'true'::text",
        )
        .expect("expected COALESCE(current_setting(...)) parse");
        let Expr::Binary {
            left, op, right, ..
        } = &parsed
        else {
            panic!("Expected Binary, got {:?}", parsed);
        };
        assert!(matches!(op, BinaryOp::Eq));
        assert!(matches!(left.as_ref(), Expr::Literal(Value::Bool(true))));
        assert!(
            matches!(right.as_ref(), Expr::Cast { target_type, .. } if target_type == "boolean")
        );
    }

    /// A parenthesised `NULLIF(current_setting(...), '')` with an outer cast
    /// takes its type from that outer cast.
    #[test]
    fn test_nullif_current_setting_cast_uuid() {
        let parsed = parse_policy_expr(
            "tenant_id = (NULLIF(current_setting('app.current_tenant_id'::text, true), ''::text))::uuid",
        )
        .expect("expected NULLIF(current_setting(...)) parse");
        let Expr::Binary {
            left, op, right, ..
        } = &parsed
        else {
            panic!("Expected Binary, got {:?}", parsed);
        };
        assert!(matches!(op, BinaryOp::Eq));
        assert!(matches!(left.as_ref(), Expr::Named(n) if n == "tenant_id"));
        assert!(matches!(right.as_ref(), Expr::Cast { target_type, .. } if target_type == "uuid"));
    }

    /// Fully wrapping parentheses are removed; partial wrapping is left alone.
    #[test]
    fn test_strip_outer_parens() {
        let cases = [
            ("(foo)", "foo"),
            ("((foo))", "foo"),
            ("(a) AND (b)", "(a) AND (b)"),
        ];
        for (input, expected) in cases {
            assert_eq!(strip_outer_parens(input), expected);
        }
    }
}