1use crate::ast::Expr;
34
/// The SQL command (or commands) a row-level security policy covers.
///
/// Formatting with `Display` yields the upper-case keyword (`ALL`, `SELECT`,
/// `INSERT`, `UPDATE`, `DELETE`), i.e. the command names accepted by the
/// `FOR` clause of PostgreSQL's `CREATE POLICY`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PolicyTarget {
    /// Every command (`ALL`).
    All,
    /// Only `SELECT`.
    Select,
    /// Only `INSERT`.
    Insert,
    /// Only `UPDATE`.
    Update,
    /// Only `DELETE`.
    Delete,
}

impl std::fmt::Display for PolicyTarget {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let keyword = match *self {
            Self::All => "ALL",
            Self::Select => "SELECT",
            Self::Insert => "INSERT",
            Self::Update => "UPDATE",
            Self::Delete => "DELETE",
        };
        f.write_str(keyword)
    }
}
61
/// Whether a row-level security policy is permissive or restrictive.
///
/// Formatting with `Display` yields `PERMISSIVE` or `RESTRICTIVE`, the
/// keywords PostgreSQL's `CREATE POLICY ... AS` clause accepts. See the
/// PostgreSQL `CREATE POLICY` documentation for how the two kinds combine.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PolicyPermissiveness {
    /// `PERMISSIVE` policy.
    Permissive,
    /// `RESTRICTIVE` policy.
    Restrictive,
}

impl std::fmt::Display for PolicyPermissiveness {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match *self {
            Self::Permissive => "PERMISSIVE",
            Self::Restrictive => "RESTRICTIVE",
        })
    }
}
79
/// A row-level security policy definition, normally assembled with
/// [`RlsPolicy::create`] followed by the chained builder methods.
///
/// The fields correspond to the clauses of a PostgreSQL `CREATE POLICY`
/// statement. This type only holds the data.
/// NOTE(review): SQL rendering is not part of this file — confirm where these
/// fields are consumed.
#[derive(Debug, Clone, PartialEq)]
pub struct RlsPolicy {
    /// Policy name.
    pub name: String,
    /// Table the policy is attached to.
    pub table: String,
    /// Command(s) the policy covers; `create` initialises this to `All`.
    pub target: PolicyTarget,
    /// Permissive vs. restrictive; `create` initialises this to `Permissive`.
    pub permissiveness: PolicyPermissiveness,
    /// Optional `USING` expression (`None` until set).
    pub using: Option<Expr>,
    /// Optional `WITH CHECK` expression (`None` until set).
    pub with_check: Option<Expr>,
    /// Optional single role the policy applies to.
    /// NOTE(review): presumably `None` means no `TO` clause (PostgreSQL then
    /// defaults to `PUBLIC`) — confirm in the renderer.
    pub role: Option<String>,
}
103
104impl RlsPolicy {
105 pub fn create(name: impl Into<String>, table: impl Into<String>) -> Self {
112 Self {
113 name: name.into(),
114 table: table.into(),
115 target: PolicyTarget::All,
116 permissiveness: PolicyPermissiveness::Permissive,
117 using: None,
118 with_check: None,
119 role: None,
120 }
121 }
122
123 pub fn for_all(mut self) -> Self {
125 self.target = PolicyTarget::All;
126 self
127 }
128
129 pub fn for_select(mut self) -> Self {
131 self.target = PolicyTarget::Select;
132 self
133 }
134
135 pub fn for_insert(mut self) -> Self {
137 self.target = PolicyTarget::Insert;
138 self
139 }
140
141 pub fn for_update(mut self) -> Self {
143 self.target = PolicyTarget::Update;
144 self
145 }
146
147 pub fn for_delete(mut self) -> Self {
149 self.target = PolicyTarget::Delete;
150 self
151 }
152
153 pub fn restrictive(mut self) -> Self {
155 self.permissiveness = PolicyPermissiveness::Restrictive;
156 self
157 }
158
159 pub fn using(mut self, expr: Expr) -> Self {
162 self.using = Some(expr);
163 self
164 }
165
166 pub fn with_check(mut self, expr: Expr) -> Self {
169 self.with_check = Some(expr);
170 self
171 }
172
173 pub fn to_role(mut self, role: impl Into<String>) -> Self {
175 self.role = Some(role.into());
176 self
177 }
178}
179
180pub fn tenant_check(
194 column: impl Into<String>,
195 session_var: impl Into<String>,
196 cast_type: impl Into<String>,
197) -> Expr {
198 use crate::ast::{BinaryOp, Value};
199
200 Expr::Binary {
201 left: Box::new(Expr::Named(column.into())),
202 op: BinaryOp::Eq,
203 right: Box::new(Expr::Cast {
204 expr: Box::new(Expr::FunctionCall {
205 name: "current_setting".into(),
206 args: vec![Expr::Literal(Value::String(session_var.into()))],
207 alias: None,
208 }),
209 target_type: cast_type.into(),
210 alias: None,
211 }),
212 alias: None,
213 }
214}
215
216pub fn session_bool_check(session_var: impl Into<String>) -> Expr {
230 use crate::ast::{BinaryOp, Value};
231
232 Expr::Binary {
233 left: Box::new(Expr::Cast {
234 expr: Box::new(Expr::FunctionCall {
235 name: "current_setting".into(),
236 args: vec![Expr::Literal(Value::String(session_var.into()))],
237 alias: None,
238 }),
239 target_type: "boolean".into(),
240 alias: None,
241 }),
242 op: BinaryOp::Eq,
243 right: Box::new(Expr::Literal(Value::Bool(true))),
244 alias: None,
245 }
246}
247
248pub fn or(left: Expr, right: Expr) -> Expr {
252 use crate::ast::BinaryOp;
253
254 Expr::Binary {
255 left: Box::new(left),
256 op: BinaryOp::Or,
257 right: Box::new(right),
258 alias: None,
259 }
260}
261
262pub fn and(left: Expr, right: Expr) -> Expr {
264 use crate::ast::BinaryOp;
265
266 Expr::Binary {
267 left: Box::new(left),
268 op: BinaryOp::And,
269 right: Box::new(right),
270 alias: None,
271 }
272}
273
// Unit tests for the RLS policy builder and the expression helpers.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::{BinaryOp, Value};

    #[test]
    fn test_policy_builder() {
        let policy = RlsPolicy::create("orders_isolation", "orders")
            .for_all()
            .using(tenant_check(
                "operator_id",
                "app.current_operator_id",
                "uuid",
            ))
            .with_check(tenant_check(
                "operator_id",
                "app.current_operator_id",
                "uuid",
            ));

        assert_eq!(policy.name, "orders_isolation");
        assert_eq!(policy.table, "orders");
        assert_eq!(policy.target, PolicyTarget::All);
        assert!(policy.using.is_some());
        assert!(policy.with_check.is_some());
    }

    // `create` alone must yield the documented initial values.
    #[test]
    fn test_policy_defaults() {
        let policy = RlsPolicy::create("p", "t");

        assert_eq!(policy.target, PolicyTarget::All);
        assert_eq!(policy.permissiveness, PolicyPermissiveness::Permissive);
        assert!(policy.using.is_none());
        assert!(policy.with_check.is_none());
        assert!(policy.role.is_none());
    }

    // Each `for_*` setter must select its own target.
    #[test]
    fn test_policy_target_setters() {
        assert_eq!(
            RlsPolicy::create("p", "t").for_insert().target,
            PolicyTarget::Insert
        );
        assert_eq!(
            RlsPolicy::create("p", "t").for_update().target,
            PolicyTarget::Update
        );
        assert_eq!(
            RlsPolicy::create("p", "t").for_delete().target,
            PolicyTarget::Delete
        );
        assert_eq!(
            RlsPolicy::create("p", "t").for_select().for_all().target,
            PolicyTarget::All
        );
    }

    #[test]
    fn test_policy_restrictive() {
        let policy = RlsPolicy::create("admin_only", "secrets")
            .for_select()
            .restrictive()
            .to_role("app_user");

        assert_eq!(policy.target, PolicyTarget::Select);
        assert_eq!(policy.permissiveness, PolicyPermissiveness::Restrictive);
        assert_eq!(policy.role.as_deref(), Some("app_user"));
    }

    #[test]
    fn test_tenant_check_helper() {
        let expr = tenant_check("operator_id", "app.current_operator_id", "uuid");

        let Expr::Binary {
            left, op, right, ..
        } = &expr
        else {
            panic!("Expected Binary, got {expr:?}");
        };
        assert_eq!(*op, BinaryOp::Eq);

        let Expr::Named(n) = left.as_ref() else {
            panic!("Expected Named, got {left:?}");
        };
        assert_eq!(n, "operator_id");

        let Expr::Cast {
            expr: cast_expr,
            target_type,
            ..
        } = right.as_ref()
        else {
            panic!("Expected Cast, got {right:?}");
        };
        assert_eq!(target_type, "uuid");

        let Expr::FunctionCall { name, args, .. } = cast_expr.as_ref() else {
            panic!("Expected FunctionCall, got {cast_expr:?}");
        };
        assert_eq!(name, "current_setting");
        assert_eq!(args.len(), 1);
    }

    // Shape: current_setting('<var>')::boolean = true
    #[test]
    fn test_session_bool_check_helper() {
        let expr = session_bool_check("app.is_super_admin");

        let Expr::Binary {
            left, op, right, ..
        } = &expr
        else {
            panic!("Expected Binary, got {expr:?}");
        };
        assert_eq!(*op, BinaryOp::Eq);

        let Expr::Cast {
            expr: cast_expr,
            target_type,
            ..
        } = left.as_ref()
        else {
            panic!("Expected Cast, got {left:?}");
        };
        assert_eq!(target_type, "boolean");

        let Expr::FunctionCall { name, args, .. } = cast_expr.as_ref() else {
            panic!("Expected FunctionCall, got {cast_expr:?}");
        };
        assert_eq!(name, "current_setting");
        assert_eq!(args.len(), 1);
        assert!(
            matches!(&args[0], Expr::Literal(Value::String(s)) if s == "app.is_super_admin"),
            "Expected session variable literal, got {cast_expr:?}"
        );

        assert!(
            matches!(right.as_ref(), Expr::Literal(Value::Bool(true))),
            "Expected literal true, got {right:?}"
        );
    }

    #[test]
    fn test_super_admin_bypass() {
        let expr = or(
            tenant_check("operator_id", "app.current_operator_id", "uuid"),
            session_bool_check("app.is_super_admin"),
        );

        assert!(
            matches!(
                &expr,
                Expr::Binary {
                    op: BinaryOp::Or,
                    ..
                }
            ),
            "Expected Binary OR, got {expr:?}"
        );
    }

    #[test]
    fn test_and_combinator() {
        let expr = and(
            tenant_check("operator_id", "app.current_operator_id", "uuid"),
            tenant_check("agent_id", "app.current_agent_id", "uuid"),
        );

        assert!(
            matches!(
                &expr,
                Expr::Binary {
                    op: BinaryOp::And,
                    ..
                }
            ),
            "Expected Binary AND, got {expr:?}"
        );
    }

    #[test]
    fn test_policy_target_display() {
        assert_eq!(PolicyTarget::All.to_string(), "ALL");
        assert_eq!(PolicyTarget::Select.to_string(), "SELECT");
        assert_eq!(PolicyTarget::Insert.to_string(), "INSERT");
        assert_eq!(PolicyTarget::Update.to_string(), "UPDATE");
        assert_eq!(PolicyTarget::Delete.to_string(), "DELETE");
    }

    #[test]
    fn test_policy_permissiveness_display() {
        assert_eq!(PolicyPermissiveness::Permissive.to_string(), "PERMISSIVE");
        assert_eq!(PolicyPermissiveness::Restrictive.to_string(), "RESTRICTIVE");
    }
}