1use crate::ast::Expr;
34use serde::{Serialize, Deserialize};
35
/// The command(s) a row-level security policy applies to.
///
/// `Display` renders each variant as its upper-case SQL keyword
/// (`ALL`, `SELECT`, `INSERT`, `UPDATE`, `DELETE`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum PolicyTarget {
    /// Every command (`ALL`); the target chosen by `RlsPolicy::create`.
    All,
    /// `SELECT` only.
    Select,
    /// `INSERT` only.
    Insert,
    /// `UPDATE` only.
    Update,
    /// `DELETE` only.
    Delete,
}
45
46impl std::fmt::Display for PolicyTarget {
47 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48 match self {
49 PolicyTarget::All => write!(f, "ALL"),
50 PolicyTarget::Select => write!(f, "SELECT"),
51 PolicyTarget::Insert => write!(f, "INSERT"),
52 PolicyTarget::Update => write!(f, "UPDATE"),
53 PolicyTarget::Delete => write!(f, "DELETE"),
54 }
55 }
56}
57
/// Whether a policy is `PERMISSIVE` or `RESTRICTIVE`.
///
/// NOTE(review): in PostgreSQL, permissive policies for a command are OR-ed
/// together and restrictive ones are AND-ed on top of them; presumably that
/// is the intended semantics here — confirm against the SQL generator.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum PolicyPermissiveness {
    /// Rendered as `PERMISSIVE`; the value chosen by `RlsPolicy::create`.
    Permissive,
    /// Rendered as `RESTRICTIVE`; selected via `RlsPolicy::restrictive`.
    Restrictive,
}
66
67impl std::fmt::Display for PolicyPermissiveness {
68 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69 match self {
70 PolicyPermissiveness::Permissive => write!(f, "PERMISSIVE"),
71 PolicyPermissiveness::Restrictive => write!(f, "RESTRICTIVE"),
72 }
73 }
74}
75
/// A row-level security policy definition, assembled with the builder
/// methods on `RlsPolicy` starting from [`RlsPolicy::create`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RlsPolicy {
    /// Name of the policy.
    pub name: String,
    /// Table the policy is defined on.
    pub table: String,
    /// Command(s) the policy covers; `create` sets [`PolicyTarget::All`].
    pub target: PolicyTarget,
    /// Permissive or restrictive; `create` sets
    /// [`PolicyPermissiveness::Permissive`].
    pub permissiveness: PolicyPermissiveness,
    /// Expression set via [`RlsPolicy::using`]; `None` until provided.
    /// Presumably rendered as the policy's `USING (...)` clause — confirm
    /// against the SQL generator.
    pub using: Option<Expr>,
    /// Expression set via [`RlsPolicy::with_check`]; `None` until provided.
    /// Presumably rendered as the policy's `WITH CHECK (...)` clause — confirm
    /// against the SQL generator.
    pub with_check: Option<Expr>,
    /// Role set via [`RlsPolicy::to_role`]; `None` when no role was given.
    /// NOTE(review): what `None` renders as (e.g. `PUBLIC` or an omitted `TO`
    /// clause) is decided elsewhere — confirm.
    pub role: Option<String>,
}
99
100impl RlsPolicy {
101 pub fn create(name: impl Into<String>, table: impl Into<String>) -> Self {
108 Self {
109 name: name.into(),
110 table: table.into(),
111 target: PolicyTarget::All,
112 permissiveness: PolicyPermissiveness::Permissive,
113 using: None,
114 with_check: None,
115 role: None,
116 }
117 }
118
119 pub fn for_all(mut self) -> Self {
121 self.target = PolicyTarget::All;
122 self
123 }
124
125 pub fn for_select(mut self) -> Self {
127 self.target = PolicyTarget::Select;
128 self
129 }
130
131 pub fn for_insert(mut self) -> Self {
133 self.target = PolicyTarget::Insert;
134 self
135 }
136
137 pub fn for_update(mut self) -> Self {
139 self.target = PolicyTarget::Update;
140 self
141 }
142
143 pub fn for_delete(mut self) -> Self {
145 self.target = PolicyTarget::Delete;
146 self
147 }
148
149 pub fn restrictive(mut self) -> Self {
151 self.permissiveness = PolicyPermissiveness::Restrictive;
152 self
153 }
154
155 pub fn using(mut self, expr: Expr) -> Self {
158 self.using = Some(expr);
159 self
160 }
161
162 pub fn with_check(mut self, expr: Expr) -> Self {
165 self.with_check = Some(expr);
166 self
167 }
168
169 pub fn to_role(mut self, role: impl Into<String>) -> Self {
171 self.role = Some(role.into());
172 self
173 }
174}
175
176pub fn tenant_check(
190 column: impl Into<String>,
191 session_var: impl Into<String>,
192 cast_type: impl Into<String>,
193) -> Expr {
194 use crate::ast::{BinaryOp, Value};
195
196 Expr::Binary {
197 left: Box::new(Expr::Named(column.into())),
198 op: BinaryOp::Eq,
199 right: Box::new(Expr::Cast {
200 expr: Box::new(Expr::FunctionCall {
201 name: "current_setting".into(),
202 args: vec![Expr::Literal(Value::String(session_var.into()))],
203 alias: None,
204 }),
205 target_type: cast_type.into(),
206 alias: None,
207 }),
208 alias: None,
209 }
210}
211
212pub fn session_bool_check(session_var: impl Into<String>) -> Expr {
226 use crate::ast::{BinaryOp, Value};
227
228 Expr::Binary {
229 left: Box::new(Expr::Cast {
230 expr: Box::new(Expr::FunctionCall {
231 name: "current_setting".into(),
232 args: vec![Expr::Literal(Value::String(session_var.into()))],
233 alias: None,
234 }),
235 target_type: "boolean".into(),
236 alias: None,
237 }),
238 op: BinaryOp::Eq,
239 right: Box::new(Expr::Literal(Value::Bool(true))),
240 alias: None,
241 }
242}
243
244pub fn or(left: Expr, right: Expr) -> Expr {
248 use crate::ast::BinaryOp;
249
250 Expr::Binary {
251 left: Box::new(left),
252 op: BinaryOp::Or,
253 right: Box::new(right),
254 alias: None,
255 }
256}
257
258pub fn and(left: Expr, right: Expr) -> Expr {
260 use crate::ast::BinaryOp;
261
262 Expr::Binary {
263 left: Box::new(left),
264 op: BinaryOp::And,
265 right: Box::new(right),
266 alias: None,
267 }
268}
269
// Unit tests for the policy builder, the `Display` impls and the expression
// helpers defined in this module.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::{BinaryOp, Value};

    #[test]
    fn test_policy_builder() {
        let policy = RlsPolicy::create("orders_isolation", "orders")
            .for_all()
            .using(tenant_check("operator_id", "app.current_operator_id", "uuid"))
            .with_check(tenant_check("operator_id", "app.current_operator_id", "uuid"));

        assert_eq!(policy.name, "orders_isolation");
        assert_eq!(policy.table, "orders");
        assert_eq!(policy.target, PolicyTarget::All);
        assert!(policy.using.is_some());
        assert!(policy.with_check.is_some());
    }

    #[test]
    fn test_policy_defaults() {
        let policy = RlsPolicy::create("p", "t");

        assert_eq!(policy.target, PolicyTarget::All);
        assert_eq!(policy.permissiveness, PolicyPermissiveness::Permissive);
        assert!(policy.using.is_none());
        assert!(policy.with_check.is_none());
        assert!(policy.role.is_none());
    }

    #[test]
    fn test_policy_restrictive() {
        let policy = RlsPolicy::create("admin_only", "secrets")
            .for_select()
            .restrictive()
            .to_role("app_user");

        assert_eq!(policy.target, PolicyTarget::Select);
        assert_eq!(policy.permissiveness, PolicyPermissiveness::Restrictive);
        assert_eq!(policy.role.as_deref(), Some("app_user"));
    }

    #[test]
    fn test_tenant_check_helper() {
        let expr = tenant_check("operator_id", "app.current_operator_id", "uuid");

        match &expr {
            Expr::Binary { left, op, right, .. } => {
                assert_eq!(*op, BinaryOp::Eq);
                match left.as_ref() {
                    Expr::Named(n) => assert_eq!(n, "operator_id"),
                    _ => panic!("Expected Named"),
                }
                match right.as_ref() {
                    Expr::Cast { expr, target_type, .. } => {
                        assert_eq!(target_type, "uuid");
                        match expr.as_ref() {
                            Expr::FunctionCall { name, args, .. } => {
                                assert_eq!(name, "current_setting");
                                assert_eq!(args.len(), 1);
                                assert!(matches!(
                                    args.as_slice(),
                                    [Expr::Literal(Value::String(v))] if v == "app.current_operator_id"
                                ));
                            }
                            _ => panic!("Expected FunctionCall"),
                        }
                    }
                    _ => panic!("Expected Cast"),
                }
            }
            _ => panic!("Expected Binary"),
        }
    }

    #[test]
    fn test_session_bool_check_helper() {
        let expr = session_bool_check("app.is_super_admin");

        match &expr {
            Expr::Binary { left, op, right, .. } => {
                assert_eq!(*op, BinaryOp::Eq);
                match left.as_ref() {
                    Expr::Cast { expr, target_type, .. } => {
                        assert_eq!(target_type, "boolean");
                        match expr.as_ref() {
                            Expr::FunctionCall { name, args, .. } => {
                                assert_eq!(name, "current_setting");
                                assert!(matches!(
                                    args.as_slice(),
                                    [Expr::Literal(Value::String(v))] if v == "app.is_super_admin"
                                ));
                            }
                            _ => panic!("Expected FunctionCall"),
                        }
                    }
                    _ => panic!("Expected Cast"),
                }
                assert!(matches!(right.as_ref(), Expr::Literal(Value::Bool(true))));
            }
            _ => panic!("Expected Binary"),
        }
    }

    #[test]
    fn test_super_admin_bypass() {
        let expr = or(
            tenant_check("operator_id", "app.current_operator_id", "uuid"),
            session_bool_check("app.is_super_admin"),
        );

        match &expr {
            Expr::Binary { left, op, right, .. } => {
                assert_eq!(*op, BinaryOp::Or);
                // Operand order is preserved: tenant check left, flag check right.
                match left.as_ref() {
                    Expr::Binary { left: col, op: BinaryOp::Eq, .. } => {
                        assert!(matches!(col.as_ref(), Expr::Named(n) if n == "operator_id"));
                    }
                    _ => panic!("Expected tenant check on the left"),
                }
                match right.as_ref() {
                    Expr::Binary { right: flag, op: BinaryOp::Eq, .. } => {
                        assert!(matches!(flag.as_ref(), Expr::Literal(Value::Bool(true))));
                    }
                    _ => panic!("Expected session flag check on the right"),
                }
            }
            _ => panic!("Expected Binary OR"),
        }
    }

    #[test]
    fn test_and_combinator() {
        let expr = and(
            tenant_check("operator_id", "app.current_operator_id", "uuid"),
            tenant_check("agent_id", "app.current_agent_id", "uuid"),
        );

        match &expr {
            Expr::Binary { left, op, right, .. } => {
                assert_eq!(*op, BinaryOp::And);
                match left.as_ref() {
                    Expr::Binary { left: col, .. } => {
                        assert!(matches!(col.as_ref(), Expr::Named(n) if n == "operator_id"));
                    }
                    _ => panic!("Expected Binary on the left"),
                }
                match right.as_ref() {
                    Expr::Binary { left: col, .. } => {
                        assert!(matches!(col.as_ref(), Expr::Named(n) if n == "agent_id"));
                    }
                    _ => panic!("Expected Binary on the right"),
                }
            }
            _ => panic!("Expected Binary AND"),
        }
    }

    #[test]
    fn test_policy_target_display() {
        assert_eq!(PolicyTarget::All.to_string(), "ALL");
        assert_eq!(PolicyTarget::Select.to_string(), "SELECT");
        assert_eq!(PolicyTarget::Insert.to_string(), "INSERT");
        assert_eq!(PolicyTarget::Update.to_string(), "UPDATE");
        assert_eq!(PolicyTarget::Delete.to_string(), "DELETE");
    }

    #[test]
    fn test_policy_permissiveness_display() {
        assert_eq!(PolicyPermissiveness::Permissive.to_string(), "PERMISSIVE");
        assert_eq!(PolicyPermissiveness::Restrictive.to_string(), "RESTRICTIVE");
    }
}
364}