1use crate::ast::Expr;
34use serde::{Deserialize, Serialize};
35
36#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
38pub enum PolicyTarget {
39 All,
41 Select,
43 Insert,
45 Update,
47 Delete,
49}
50
51impl std::fmt::Display for PolicyTarget {
52 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53 match self {
54 PolicyTarget::All => write!(f, "ALL"),
55 PolicyTarget::Select => write!(f, "SELECT"),
56 PolicyTarget::Insert => write!(f, "INSERT"),
57 PolicyTarget::Update => write!(f, "UPDATE"),
58 PolicyTarget::Delete => write!(f, "DELETE"),
59 }
60 }
61}
62
63#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
65pub enum PolicyPermissiveness {
66 Permissive,
68 Restrictive,
70}
71
72impl std::fmt::Display for PolicyPermissiveness {
73 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
74 match self {
75 PolicyPermissiveness::Permissive => write!(f, "PERMISSIVE"),
76 PolicyPermissiveness::Restrictive => write!(f, "RESTRICTIVE"),
77 }
78 }
79}
80
81#[derive(Debug, Clone, Serialize, Deserialize)]
86pub struct RlsPolicy {
87 pub name: String,
89 pub table: String,
91 pub target: PolicyTarget,
93 pub permissiveness: PolicyPermissiveness,
95 pub using: Option<Expr>,
98 pub with_check: Option<Expr>,
101 pub role: Option<String>,
103}
104
105impl RlsPolicy {
106 pub fn create(name: impl Into<String>, table: impl Into<String>) -> Self {
113 Self {
114 name: name.into(),
115 table: table.into(),
116 target: PolicyTarget::All,
117 permissiveness: PolicyPermissiveness::Permissive,
118 using: None,
119 with_check: None,
120 role: None,
121 }
122 }
123
124 pub fn for_all(mut self) -> Self {
126 self.target = PolicyTarget::All;
127 self
128 }
129
130 pub fn for_select(mut self) -> Self {
132 self.target = PolicyTarget::Select;
133 self
134 }
135
136 pub fn for_insert(mut self) -> Self {
138 self.target = PolicyTarget::Insert;
139 self
140 }
141
142 pub fn for_update(mut self) -> Self {
144 self.target = PolicyTarget::Update;
145 self
146 }
147
148 pub fn for_delete(mut self) -> Self {
150 self.target = PolicyTarget::Delete;
151 self
152 }
153
154 pub fn restrictive(mut self) -> Self {
156 self.permissiveness = PolicyPermissiveness::Restrictive;
157 self
158 }
159
160 pub fn using(mut self, expr: Expr) -> Self {
163 self.using = Some(expr);
164 self
165 }
166
167 pub fn with_check(mut self, expr: Expr) -> Self {
170 self.with_check = Some(expr);
171 self
172 }
173
174 pub fn to_role(mut self, role: impl Into<String>) -> Self {
176 self.role = Some(role.into());
177 self
178 }
179}
180
181pub fn tenant_check(
195 column: impl Into<String>,
196 session_var: impl Into<String>,
197 cast_type: impl Into<String>,
198) -> Expr {
199 use crate::ast::{BinaryOp, Value};
200
201 Expr::Binary {
202 left: Box::new(Expr::Named(column.into())),
203 op: BinaryOp::Eq,
204 right: Box::new(Expr::Cast {
205 expr: Box::new(Expr::FunctionCall {
206 name: "current_setting".into(),
207 args: vec![Expr::Literal(Value::String(session_var.into()))],
208 alias: None,
209 }),
210 target_type: cast_type.into(),
211 alias: None,
212 }),
213 alias: None,
214 }
215}
216
217pub fn session_bool_check(session_var: impl Into<String>) -> Expr {
231 use crate::ast::{BinaryOp, Value};
232
233 Expr::Binary {
234 left: Box::new(Expr::Cast {
235 expr: Box::new(Expr::FunctionCall {
236 name: "current_setting".into(),
237 args: vec![Expr::Literal(Value::String(session_var.into()))],
238 alias: None,
239 }),
240 target_type: "boolean".into(),
241 alias: None,
242 }),
243 op: BinaryOp::Eq,
244 right: Box::new(Expr::Literal(Value::Bool(true))),
245 alias: None,
246 }
247}
248
249pub fn or(left: Expr, right: Expr) -> Expr {
253 use crate::ast::BinaryOp;
254
255 Expr::Binary {
256 left: Box::new(left),
257 op: BinaryOp::Or,
258 right: Box::new(right),
259 alias: None,
260 }
261}
262
263pub fn and(left: Expr, right: Expr) -> Expr {
265 use crate::ast::BinaryOp;
266
267 Expr::Binary {
268 left: Box::new(left),
269 op: BinaryOp::And,
270 right: Box::new(right),
271 alias: None,
272 }
273}
274
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::{BinaryOp, Value};

    /// Asserts that `expr` is a cast to `expected_type` wrapping a
    /// `current_setting` call whose only argument is the string literal
    /// `expected_var`.
    fn assert_setting_cast(expr: &Expr, expected_var: &str, expected_type: &str) {
        let Expr::Cast {
            expr: inner,
            target_type,
            ..
        } = expr
        else {
            panic!("Expected Cast, got {expr:?}");
        };
        assert_eq!(target_type, expected_type);

        let Expr::FunctionCall { name, args, .. } = inner.as_ref() else {
            panic!("Expected FunctionCall, got {inner:?}");
        };
        assert_eq!(name, "current_setting");
        assert_eq!(args.len(), 1);

        let Some(Expr::Literal(Value::String(setting))) = args.first() else {
            panic!("Expected a string literal argument, got {args:?}");
        };
        assert_eq!(setting, expected_var);
    }

    #[test]
    fn test_create_defaults() {
        let policy = RlsPolicy::create("orders_isolation", "orders");

        assert_eq!(policy.name, "orders_isolation");
        assert_eq!(policy.table, "orders");
        assert_eq!(policy.target, PolicyTarget::All);
        assert_eq!(policy.permissiveness, PolicyPermissiveness::Permissive);
        assert!(policy.using.is_none());
        assert!(policy.with_check.is_none());
        assert!(policy.role.is_none());
    }

    #[test]
    fn test_policy_builder() {
        let policy = RlsPolicy::create("orders_isolation", "orders")
            .for_all()
            .using(tenant_check(
                "operator_id",
                "app.current_operator_id",
                "uuid",
            ))
            .with_check(tenant_check(
                "operator_id",
                "app.current_operator_id",
                "uuid",
            ));

        assert_eq!(policy.name, "orders_isolation");
        assert_eq!(policy.table, "orders");
        assert_eq!(policy.target, PolicyTarget::All);
        assert_eq!(policy.permissiveness, PolicyPermissiveness::Permissive);
        assert!(policy.using.is_some());
        assert!(policy.with_check.is_some());
        assert!(policy.role.is_none());
    }

    #[test]
    fn test_target_builders() {
        let base = || RlsPolicy::create("p", "t");

        assert_eq!(base().for_select().target, PolicyTarget::Select);
        assert_eq!(base().for_insert().target, PolicyTarget::Insert);
        assert_eq!(base().for_update().target, PolicyTarget::Update);
        assert_eq!(base().for_delete().target, PolicyTarget::Delete);
        // The last target call wins.
        assert_eq!(base().for_delete().for_all().target, PolicyTarget::All);
    }

    #[test]
    fn test_policy_restrictive() {
        let policy = RlsPolicy::create("admin_only", "secrets")
            .for_select()
            .restrictive()
            .to_role("app_user");

        assert_eq!(policy.target, PolicyTarget::Select);
        assert_eq!(policy.permissiveness, PolicyPermissiveness::Restrictive);
        assert_eq!(policy.role.as_deref(), Some("app_user"));
    }

    #[test]
    fn test_tenant_check_helper() {
        let expr = tenant_check("operator_id", "app.current_operator_id", "uuid");

        let Expr::Binary {
            left, op, right, ..
        } = &expr
        else {
            panic!("Expected Binary, got {expr:?}");
        };
        assert_eq!(*op, BinaryOp::Eq);

        let Expr::Named(n) = left.as_ref() else {
            panic!("Expected Named, got {left:?}");
        };
        assert_eq!(n, "operator_id");

        assert_setting_cast(right.as_ref(), "app.current_operator_id", "uuid");
    }

    #[test]
    fn test_session_bool_check_helper() {
        let expr = session_bool_check("app.is_super_admin");

        let Expr::Binary {
            left, op, right, ..
        } = &expr
        else {
            panic!("Expected Binary, got {expr:?}");
        };
        assert_eq!(*op, BinaryOp::Eq);

        assert_setting_cast(left.as_ref(), "app.is_super_admin", "boolean");
        assert!(
            matches!(right.as_ref(), Expr::Literal(Value::Bool(true))),
            "Expected literal true, got {right:?}"
        );
    }

    #[test]
    fn test_super_admin_bypass() {
        let expr = or(
            tenant_check("operator_id", "app.current_operator_id", "uuid"),
            session_bool_check("app.is_super_admin"),
        );

        assert!(
            matches!(
                &expr,
                Expr::Binary {
                    op: BinaryOp::Or,
                    ..
                }
            ),
            "Expected Binary OR, got {expr:?}"
        );
    }

    #[test]
    fn test_and_combinator() {
        let expr = and(
            tenant_check("operator_id", "app.current_operator_id", "uuid"),
            tenant_check("agent_id", "app.current_agent_id", "uuid"),
        );

        assert!(
            matches!(
                &expr,
                Expr::Binary {
                    op: BinaryOp::And,
                    ..
                }
            ),
            "Expected Binary AND, got {expr:?}"
        );
    }

    #[test]
    fn test_policy_target_display() {
        assert_eq!(PolicyTarget::All.to_string(), "ALL");
        assert_eq!(PolicyTarget::Select.to_string(), "SELECT");
        assert_eq!(PolicyTarget::Insert.to_string(), "INSERT");
        assert_eq!(PolicyTarget::Update.to_string(), "UPDATE");
        assert_eq!(PolicyTarget::Delete.to_string(), "DELETE");
    }

    #[test]
    fn test_policy_permissiveness_display() {
        assert_eq!(PolicyPermissiveness::Permissive.to_string(), "PERMISSIVE");
        assert_eq!(PolicyPermissiveness::Restrictive.to_string(), "RESTRICTIVE");
    }
}