1use crate::ast::Expr;
34
35#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
37pub enum PolicyTarget {
38 All,
40 Select,
42 Insert,
44 Update,
46 Delete,
48}
49
50impl std::fmt::Display for PolicyTarget {
51 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52 match self {
53 PolicyTarget::All => write!(f, "ALL"),
54 PolicyTarget::Select => write!(f, "SELECT"),
55 PolicyTarget::Insert => write!(f, "INSERT"),
56 PolicyTarget::Update => write!(f, "UPDATE"),
57 PolicyTarget::Delete => write!(f, "DELETE"),
58 }
59 }
60}
61
62#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
64pub enum PolicyPermissiveness {
65 Permissive,
67 Restrictive,
69}
70
71impl std::fmt::Display for PolicyPermissiveness {
72 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73 match self {
74 PolicyPermissiveness::Permissive => write!(f, "PERMISSIVE"),
75 PolicyPermissiveness::Restrictive => write!(f, "RESTRICTIVE"),
76 }
77 }
78}
79
80#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
85pub struct RlsPolicy {
86 pub name: String,
88 pub table: String,
90 pub target: PolicyTarget,
92 pub permissiveness: PolicyPermissiveness,
94 pub using: Option<Expr>,
97 pub with_check: Option<Expr>,
100 pub role: Option<String>,
102}
103
104impl RlsPolicy {
105 pub fn create(name: impl Into<String>, table: impl Into<String>) -> Self {
112 Self {
113 name: name.into(),
114 table: table.into(),
115 target: PolicyTarget::All,
116 permissiveness: PolicyPermissiveness::Permissive,
117 using: None,
118 with_check: None,
119 role: None,
120 }
121 }
122
123 pub fn for_all(mut self) -> Self {
125 self.target = PolicyTarget::All;
126 self
127 }
128
129 pub fn for_select(mut self) -> Self {
131 self.target = PolicyTarget::Select;
132 self
133 }
134
135 pub fn for_insert(mut self) -> Self {
137 self.target = PolicyTarget::Insert;
138 self
139 }
140
141 pub fn for_update(mut self) -> Self {
143 self.target = PolicyTarget::Update;
144 self
145 }
146
147 pub fn for_delete(mut self) -> Self {
149 self.target = PolicyTarget::Delete;
150 self
151 }
152
153 pub fn restrictive(mut self) -> Self {
155 self.permissiveness = PolicyPermissiveness::Restrictive;
156 self
157 }
158
159 pub fn using(mut self, expr: Expr) -> Self {
162 self.using = Some(expr);
163 self
164 }
165
166 pub fn with_check(mut self, expr: Expr) -> Self {
169 self.with_check = Some(expr);
170 self
171 }
172
173 pub fn to_role(mut self, role: impl Into<String>) -> Self {
175 self.role = Some(role.into());
176 self
177 }
178}
179
180pub fn tenant_check(
194 column: impl Into<String>,
195 session_var: impl Into<String>,
196 cast_type: impl Into<String>,
197) -> Expr {
198 use crate::ast::{BinaryOp, Value};
199
200 Expr::Binary {
201 left: Box::new(Expr::Named(column.into())),
202 op: BinaryOp::Eq,
203 right: Box::new(Expr::Cast {
204 expr: Box::new(Expr::FunctionCall {
205 name: "current_setting".into(),
206 args: vec![Expr::Literal(Value::String(session_var.into()))],
207 alias: None,
208 }),
209 target_type: cast_type.into(),
210 alias: None,
211 }),
212 alias: None,
213 }
214}
215
216pub fn session_bool_check(session_var: impl Into<String>) -> Expr {
230 use crate::ast::{BinaryOp, Value};
231
232 Expr::Binary {
233 left: Box::new(Expr::Cast {
234 expr: Box::new(Expr::FunctionCall {
235 name: "current_setting".into(),
236 args: vec![Expr::Literal(Value::String(session_var.into()))],
237 alias: None,
238 }),
239 target_type: "boolean".into(),
240 alias: None,
241 }),
242 op: BinaryOp::Eq,
243 right: Box::new(Expr::Literal(Value::Bool(true))),
244 alias: None,
245 }
246}
247
248pub fn or(left: Expr, right: Expr) -> Expr {
252 use crate::ast::BinaryOp;
253
254 Expr::Binary {
255 left: Box::new(left),
256 op: BinaryOp::Or,
257 right: Box::new(right),
258 alias: None,
259 }
260}
261
262pub fn and(left: Expr, right: Expr) -> Expr {
264 use crate::ast::BinaryOp;
265
266 Expr::Binary {
267 left: Box::new(left),
268 op: BinaryOp::And,
269 right: Box::new(right),
270 alias: None,
271 }
272}
273
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::{BinaryOp, Value};

    #[test]
    fn test_policy_builder() {
        let policy = RlsPolicy::create("orders_isolation", "orders")
            .for_all()
            .using(tenant_check("tenant_id", "app.current_tenant_id", "uuid"))
            .with_check(tenant_check("tenant_id", "app.current_tenant_id", "uuid"));

        assert_eq!(policy.name, "orders_isolation");
        assert_eq!(policy.table, "orders");
        assert_eq!(policy.target, PolicyTarget::All);
        assert!(policy.using.is_some());
        assert!(policy.with_check.is_some());
    }

    #[test]
    fn test_policy_defaults() {
        let policy = RlsPolicy::create("p", "t");

        assert_eq!(policy.target, PolicyTarget::All);
        assert_eq!(policy.permissiveness, PolicyPermissiveness::Permissive);
        assert!(policy.using.is_none());
        assert!(policy.with_check.is_none());
        assert!(policy.role.is_none());
    }

    #[test]
    fn test_policy_target_setters() {
        let base = RlsPolicy::create("p", "t");

        assert_eq!(base.clone().for_insert().target, PolicyTarget::Insert);
        assert_eq!(base.clone().for_update().target, PolicyTarget::Update);
        assert_eq!(base.clone().for_delete().target, PolicyTarget::Delete);
        // The last target setter wins.
        assert_eq!(base.for_select().for_all().target, PolicyTarget::All);
    }

    #[test]
    fn test_policy_restrictive() {
        let policy = RlsPolicy::create("admin_only", "secrets")
            .for_select()
            .restrictive()
            .to_role("app_user");

        assert_eq!(policy.target, PolicyTarget::Select);
        assert_eq!(policy.permissiveness, PolicyPermissiveness::Restrictive);
        assert_eq!(policy.role.as_deref(), Some("app_user"));
    }

    #[test]
    fn test_tenant_check_helper() {
        let expr = tenant_check("tenant_id", "app.current_tenant_id", "uuid");

        let Expr::Binary {
            left, op, right, ..
        } = &expr
        else {
            panic!("Expected Binary, got {expr:?}");
        };
        assert_eq!(*op, BinaryOp::Eq);

        let Expr::Named(n) = left.as_ref() else {
            panic!("Expected Named, got {left:?}");
        };
        assert_eq!(n, "tenant_id");

        let Expr::Cast {
            expr: cast_expr,
            target_type,
            ..
        } = right.as_ref()
        else {
            panic!("Expected Cast, got {right:?}");
        };
        assert_eq!(target_type, "uuid");

        let Expr::FunctionCall { name, args, .. } = cast_expr.as_ref() else {
            panic!("Expected FunctionCall, got {cast_expr:?}");
        };
        assert_eq!(name, "current_setting");

        let [arg] = args.as_slice() else {
            panic!("Expected exactly one argument, got {args:?}");
        };
        let Expr::Literal(Value::String(var)) = arg else {
            panic!("Expected string literal argument, got {arg:?}");
        };
        assert_eq!(var, "app.current_tenant_id");
    }

    #[test]
    fn test_session_bool_check_helper() {
        let expr = session_bool_check("app.is_super_admin");

        let Expr::Binary {
            left, op, right, ..
        } = &expr
        else {
            panic!("Expected Binary, got {expr:?}");
        };
        assert_eq!(*op, BinaryOp::Eq);
        assert!(
            matches!(left.as_ref(), Expr::Cast { target_type, .. } if target_type == "boolean"),
            "Expected Cast to boolean, got {left:?}"
        );
        assert!(
            matches!(right.as_ref(), Expr::Literal(Value::Bool(true))),
            "Expected literal true, got {right:?}"
        );
    }

    #[test]
    fn test_super_admin_bypass() {
        let expr = or(
            tenant_check("tenant_id", "app.current_tenant_id", "uuid"),
            session_bool_check("app.is_super_admin"),
        );

        assert!(
            matches!(
                &expr,
                Expr::Binary {
                    op: BinaryOp::Or,
                    ..
                }
            ),
            "Expected Binary OR, got {expr:?}"
        );
    }

    #[test]
    fn test_and_combinator() {
        let expr = and(
            tenant_check("tenant_id", "app.current_tenant_id", "uuid"),
            tenant_check("agent_id", "app.current_agent_id", "uuid"),
        );

        assert!(
            matches!(
                &expr,
                Expr::Binary {
                    op: BinaryOp::And,
                    ..
                }
            ),
            "Expected Binary AND, got {expr:?}"
        );
    }

    #[test]
    fn test_policy_target_display() {
        assert_eq!(PolicyTarget::All.to_string(), "ALL");
        assert_eq!(PolicyTarget::Select.to_string(), "SELECT");
        assert_eq!(PolicyTarget::Insert.to_string(), "INSERT");
        assert_eq!(PolicyTarget::Update.to_string(), "UPDATE");
        assert_eq!(PolicyTarget::Delete.to_string(), "DELETE");
    }

    #[test]
    fn test_policy_permissiveness_display() {
        assert_eq!(PolicyPermissiveness::Permissive.to_string(), "PERMISSIVE");
        assert_eq!(PolicyPermissiveness::Restrictive.to_string(), "RESTRICTIVE");
    }
}