prax_schema/ast/
policy.rs

1//! Row-Level Security (RLS) policy definitions for the Prax schema AST.
2//!
3//! Policies enable fine-grained access control at the row level.
4//! They are evaluated for each row that a query accesses and determine whether
5//! the row should be visible or modifiable based on the policy expression.
6//!
7//! ## Supported Databases
8//!
9//! - **PostgreSQL**: Native RLS with CREATE POLICY
10//! - **SQL Server (MSSQL)**: Security Policies with predicate functions
11//!
12//! ## PostgreSQL RLS
13//!
14//! PostgreSQL uses `CREATE POLICY` statements with USING (filter) and WITH CHECK
15//! (block) expressions evaluated inline.
16//!
17//! ## SQL Server RLS
18//!
19//! SQL Server requires:
20//! 1. A schema-bound inline table-valued function (predicate function)
21//! 2. A security policy binding the function to the table
22//!
23//! The predicate function returns 1 for rows that should be accessible.
24
25use serde::{Deserialize, Serialize};
26use smol_str::SmolStr;
27
28use super::{Documentation, Ident, Span};
29
/// A Row-Level Security (RLS) policy definition.
///
/// Policies provide fine-grained access control at the row level.
/// They are attached to a model/table and evaluated for each row an
/// operation touches.
///
/// # Example Schema Syntax
///
/// ```text
/// policy UserReadOwnData on User {
///     for     SELECT
///     to      authenticated
///     using   "id = current_user_id()"
/// }
///
/// policy UserModifyOwnData on User {
///     for     [INSERT, UPDATE, DELETE]
///     to      authenticated
///     using   "id = current_user_id()"
///     check   "id = current_user_id()"
/// }
/// ```
///
/// # Database Support
///
/// - PostgreSQL: rendered as `CREATE POLICY` (see [`Policy::to_postgres_sql`]).
/// - SQL Server: rendered as a predicate function plus `CREATE SECURITY POLICY`
///   (see [`Policy::to_mssql_sql`]).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Policy {
    /// Policy name (must be unique per table).
    pub name: Ident,
    /// The model/table this policy applies to.
    pub table: Ident,
    /// Policy type: PERMISSIVE (what [`Policy::new`] sets) or RESTRICTIVE.
    pub policy_type: PolicyType,
    /// Commands this policy applies to (SELECT, INSERT, UPDATE, DELETE, or ALL).
    /// [`Policy::new`] initialises this to `[PolicyCommand::All]`.
    pub commands: Vec<PolicyCommand>,
    /// Roles this policy applies to. An empty list is rendered as `PUBLIC`
    /// (see [`Policy::effective_roles`]).
    pub roles: Vec<SmolStr>,
    /// USING expression, evaluated against existing rows (SELECT, UPDATE, DELETE).
    /// Should evaluate to a boolean; a row is visible when it is true.
    /// In MSSQL, this becomes the FILTER PREDICATE.
    ///
    /// NOTE(review): the text is spliced verbatim into generated SQL without
    /// escaping or validation, so it must come from trusted schema source.
    pub using_expr: Option<String>,
    /// WITH CHECK expression, evaluated against new rows (INSERT, UPDATE).
    /// Should evaluate to a boolean; the write is allowed when it is true.
    /// In MSSQL, this becomes the BLOCK PREDICATE(s).
    ///
    /// NOTE(review): spliced verbatim into generated SQL, like `using_expr`.
    pub check_expr: Option<String>,
    /// MSSQL-specific: schema that holds the predicate function and the security
    /// policy. `None` means `"Security"` (see [`Policy::mssql_schema`]).
    pub mssql_schema: Option<SmolStr>,
    /// MSSQL-specific: BLOCK PREDICATE operations to emit. When empty, they are
    /// derived from `commands` at SQL-generation time.
    pub mssql_block_operations: Vec<MssqlBlockOperation>,
    /// Documentation comment attached to the policy in the schema.
    pub documentation: Option<Documentation>,
    /// Source location of the policy definition.
    pub span: Span,
}
85
86impl Policy {
87    /// Create a new policy.
88    pub fn new(name: Ident, table: Ident, span: Span) -> Self {
89        Self {
90            name,
91            table,
92            policy_type: PolicyType::Permissive,
93            commands: vec![PolicyCommand::All],
94            roles: vec![],
95            using_expr: None,
96            check_expr: None,
97            mssql_schema: None,
98            mssql_block_operations: vec![],
99            documentation: None,
100            span,
101        }
102    }
103
104    /// Get the policy name as a string.
105    pub fn name(&self) -> &str {
106        self.name.as_str()
107    }
108
109    /// Get the table name as a string.
110    pub fn table(&self) -> &str {
111        self.table.as_str()
112    }
113
114    /// Set the policy type.
115    pub fn with_type(mut self, policy_type: PolicyType) -> Self {
116        self.policy_type = policy_type;
117        self
118    }
119
120    /// Set the commands this policy applies to.
121    pub fn with_commands(mut self, commands: Vec<PolicyCommand>) -> Self {
122        self.commands = commands;
123        self
124    }
125
126    /// Add a command this policy applies to.
127    pub fn add_command(&mut self, command: PolicyCommand) {
128        self.commands.push(command);
129    }
130
131    /// Set the roles this policy applies to.
132    pub fn with_roles(mut self, roles: Vec<SmolStr>) -> Self {
133        self.roles = roles;
134        self
135    }
136
137    /// Add a role this policy applies to.
138    pub fn add_role(&mut self, role: impl Into<SmolStr>) {
139        self.roles.push(role.into());
140    }
141
142    /// Set the USING expression.
143    pub fn with_using(mut self, expr: impl Into<String>) -> Self {
144        self.using_expr = Some(expr.into());
145        self
146    }
147
148    /// Set the WITH CHECK expression.
149    pub fn with_check(mut self, expr: impl Into<String>) -> Self {
150        self.check_expr = Some(expr.into());
151        self
152    }
153
154    /// Set documentation.
155    pub fn with_documentation(mut self, doc: Documentation) -> Self {
156        self.documentation = Some(doc);
157        self
158    }
159
160    /// Set the MSSQL schema for the predicate function.
161    pub fn with_mssql_schema(mut self, schema: impl Into<SmolStr>) -> Self {
162        self.mssql_schema = Some(schema.into());
163        self
164    }
165
166    /// Set the MSSQL block operations.
167    pub fn with_mssql_block_operations(mut self, operations: Vec<MssqlBlockOperation>) -> Self {
168        self.mssql_block_operations = operations;
169        self
170    }
171
172    /// Add an MSSQL block operation.
173    pub fn add_mssql_block_operation(&mut self, operation: MssqlBlockOperation) {
174        self.mssql_block_operations.push(operation);
175    }
176
177    /// Check if this policy applies to a specific command.
178    pub fn applies_to(&self, command: PolicyCommand) -> bool {
179        self.commands.contains(&PolicyCommand::All) || self.commands.contains(&command)
180    }
181
182    /// Check if this policy is restrictive.
183    pub fn is_restrictive(&self) -> bool {
184        self.policy_type == PolicyType::Restrictive
185    }
186
187    /// Check if this policy is permissive.
188    pub fn is_permissive(&self) -> bool {
189        self.policy_type == PolicyType::Permissive
190    }
191
192    /// Get the effective roles (PUBLIC if none specified).
193    pub fn effective_roles(&self) -> Vec<&str> {
194        if self.roles.is_empty() {
195            vec!["PUBLIC"]
196        } else {
197            self.roles.iter().map(|r| r.as_str()).collect()
198        }
199    }
200
201    /// Get the MSSQL schema (default: "Security").
202    pub fn mssql_schema(&self) -> &str {
203        self.mssql_schema
204            .as_ref()
205            .map(|s| s.as_str())
206            .unwrap_or("Security")
207    }
208
209    /// Get the predicate function name for MSSQL.
210    pub fn mssql_predicate_function_name(&self) -> String {
211        format!("fn_{}_predicate", self.name())
212    }
213
214    /// Generate the PostgreSQL CREATE POLICY statement.
215    pub fn to_sql(&self, table_name: &str) -> String {
216        self.to_postgres_sql(table_name)
217    }
218
219    /// Generate the PostgreSQL CREATE POLICY statement.
220    pub fn to_postgres_sql(&self, table_name: &str) -> String {
221        let mut sql = format!("CREATE POLICY {} ON {}", self.name(), table_name);
222
223        // AS PERMISSIVE/RESTRICTIVE
224        match self.policy_type {
225            PolicyType::Permissive => {} // Default, no need to specify
226            PolicyType::Restrictive => sql.push_str(" AS RESTRICTIVE"),
227        }
228
229        // FOR command
230        if !self.commands.is_empty() && !self.commands.contains(&PolicyCommand::All) {
231            let cmds: Vec<&str> = self.commands.iter().map(|c| c.as_str()).collect();
232            // PostgreSQL only allows a single command, use first one
233            sql.push_str(&format!(" FOR {}", cmds[0]));
234        }
235
236        // TO roles
237        let roles = self.effective_roles();
238        sql.push_str(&format!(" TO {}", roles.join(", ")));
239
240        // USING expression
241        if let Some(ref using) = self.using_expr {
242            sql.push_str(&format!(" USING ({})", using));
243        }
244
245        // WITH CHECK expression
246        if let Some(ref check) = self.check_expr {
247            sql.push_str(&format!(" WITH CHECK ({})", check));
248        }
249
250        sql
251    }
252
253    /// Generate SQL Server (MSSQL) security policy statements.
254    ///
255    /// Returns a tuple of:
256    /// 1. CREATE FUNCTION statement for the predicate function
257    /// 2. CREATE SECURITY POLICY statement
258    ///
259    /// # Arguments
260    ///
261    /// * `table_name` - The fully qualified table name (e.g., "dbo.Users")
262    /// * `predicate_column` - The column name to use in the predicate (e.g., "UserId")
263    ///
264    /// # Example
265    ///
266    /// ```ignore
267    /// let (func_sql, policy_sql) = policy.to_mssql_sql("dbo.Users", "UserId");
268    /// ```
269    pub fn to_mssql_sql(&self, table_name: &str, predicate_column: &str) -> MssqlPolicyStatements {
270        let schema = self.mssql_schema();
271        let func_name = self.mssql_predicate_function_name();
272
273        // Generate the predicate function
274        let filter_expr = self
275            .using_expr
276            .as_deref()
277            .unwrap_or("1 = 1")
278            .replace(
279                "current_user_id()",
280                "CAST(SESSION_CONTEXT(N'UserId') AS INT)",
281            )
282            .replace("auth.uid()", "CAST(SESSION_CONTEXT(N'UserId') AS INT)")
283            .replace(
284                "current_setting('app.current_org')",
285                "SESSION_CONTEXT(N'OrgId')",
286            );
287
288        let function_sql = format!(
289            r#"CREATE FUNCTION {schema}.{func_name}(@{predicate_column} AS INT)
290    RETURNS TABLE
291WITH SCHEMABINDING
292AS
293    RETURN SELECT 1 AS fn_securitypredicate_result
294    WHERE {filter_expr}"#,
295            schema = schema,
296            func_name = func_name,
297            predicate_column = predicate_column,
298            filter_expr = filter_expr
299        );
300
301        // Generate the security policy
302        let mut policy_sql = format!(
303            "CREATE SECURITY POLICY {schema}.{policy_name}\n",
304            schema = schema,
305            policy_name = self.name()
306        );
307
308        // Add FILTER PREDICATE if we have a using expression
309        if self.using_expr.is_some() {
310            policy_sql.push_str(&format!(
311                "ADD FILTER PREDICATE {schema}.{func_name}({predicate_column}) ON {table_name}",
312                schema = schema,
313                func_name = func_name,
314                predicate_column = predicate_column,
315                table_name = table_name
316            ));
317        }
318
319        // Add BLOCK PREDICATE(s) if we have a check expression
320        if self.check_expr.is_some() {
321            let block_ops = if self.mssql_block_operations.is_empty() {
322                // Default block operations based on commands
323                self.default_mssql_block_operations()
324            } else {
325                self.mssql_block_operations.clone()
326            };
327
328            for (i, op) in block_ops.iter().enumerate() {
329                if i > 0 || self.using_expr.is_some() {
330                    policy_sql.push_str(",\n");
331                }
332                policy_sql.push_str(&format!(
333                    "ADD BLOCK PREDICATE {schema}.{func_name}({predicate_column}) ON {table_name} {op}",
334                    schema = schema,
335                    func_name = func_name,
336                    predicate_column = predicate_column,
337                    table_name = table_name,
338                    op = op.as_str()
339                ));
340            }
341        }
342
343        policy_sql.push_str("\nWITH (STATE = ON)");
344
345        MssqlPolicyStatements {
346            schema_sql: format!("CREATE SCHEMA {schema}"),
347            function_sql,
348            policy_sql,
349        }
350    }
351
352    /// Get default MSSQL block operations based on the policy commands.
353    fn default_mssql_block_operations(&self) -> Vec<MssqlBlockOperation> {
354        let mut ops = vec![];
355
356        if self.applies_to(PolicyCommand::Insert) {
357            ops.push(MssqlBlockOperation::AfterInsert);
358        }
359        if self.applies_to(PolicyCommand::Update) {
360            ops.push(MssqlBlockOperation::AfterUpdate);
361            ops.push(MssqlBlockOperation::BeforeUpdate);
362        }
363        if self.applies_to(PolicyCommand::Delete) {
364            ops.push(MssqlBlockOperation::BeforeDelete);
365        }
366
367        ops
368    }
369}
370
/// PostgreSQL policy type, rendered as `AS PERMISSIVE` / `AS RESTRICTIVE`.
///
/// Defaults to [`PolicyType::Permissive`], which is also the mode PostgreSQL
/// assumes when no `AS` clause is given.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
pub enum PolicyType {
    /// Permissive policies are combined with OR.
    /// At least one permissive policy must allow access.
    #[default]
    Permissive,
    /// Restrictive policies are combined with AND.
    /// All restrictive policies must allow access.
    Restrictive,
}
382
383impl PolicyType {
384    /// Parse a policy type from a string.
385    #[allow(clippy::should_implement_trait)]
386    pub fn from_str(s: &str) -> Option<Self> {
387        match s.to_uppercase().as_str() {
388            "PERMISSIVE" => Some(Self::Permissive),
389            "RESTRICTIVE" => Some(Self::Restrictive),
390            _ => None,
391        }
392    }
393
394    /// Get the SQL keyword for this policy type.
395    pub fn as_str(&self) -> &'static str {
396        match self {
397            Self::Permissive => "PERMISSIVE",
398            Self::Restrictive => "RESTRICTIVE",
399        }
400    }
401}
402
403impl std::fmt::Display for PolicyType {
404    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
405        write!(f, "{}", self.as_str())
406    }
407}
408
/// PostgreSQL policy command type (the `FOR` clause of `CREATE POLICY`).
///
/// The same values also decide which default MSSQL block operations a policy
/// gets when none are configured explicitly.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum PolicyCommand {
    /// Policy applies to all commands (SELECT, INSERT, UPDATE, DELETE).
    All,
    /// Policy applies to SELECT queries.
    Select,
    /// Policy applies to INSERT statements.
    Insert,
    /// Policy applies to UPDATE statements.
    Update,
    /// Policy applies to DELETE statements.
    Delete,
}
423
424impl PolicyCommand {
425    /// Parse a policy command from a string.
426    #[allow(clippy::should_implement_trait)]
427    pub fn from_str(s: &str) -> Option<Self> {
428        match s.to_uppercase().as_str() {
429            "ALL" => Some(Self::All),
430            "SELECT" => Some(Self::Select),
431            "INSERT" => Some(Self::Insert),
432            "UPDATE" => Some(Self::Update),
433            "DELETE" => Some(Self::Delete),
434            _ => None,
435        }
436    }
437
438    /// Get the SQL keyword for this command.
439    pub fn as_str(&self) -> &'static str {
440        match self {
441            Self::All => "ALL",
442            Self::Select => "SELECT",
443            Self::Insert => "INSERT",
444            Self::Update => "UPDATE",
445            Self::Delete => "DELETE",
446        }
447    }
448
449    /// Check if this command requires a USING expression.
450    pub fn requires_using(&self) -> bool {
451        matches!(self, Self::All | Self::Select | Self::Update | Self::Delete)
452    }
453
454    /// Check if this command requires a WITH CHECK expression.
455    pub fn requires_check(&self) -> bool {
456        matches!(self, Self::All | Self::Insert | Self::Update)
457    }
458}
459
460impl std::fmt::Display for PolicyCommand {
461    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
462        write!(f, "{}", self.as_str())
463    }
464}
465
/// MSSQL-specific block operation types.
///
/// SQL Server's BLOCK PREDICATE can be applied at different points in the
/// data modification lifecycle. The clause emitted after `ON <table>` is the
/// string returned by `as_str` (e.g. `AFTER INSERT`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum MssqlBlockOperation {
    /// Block predicate evaluated after INSERT.
    /// Prevents inserting rows that don't satisfy the predicate.
    AfterInsert,
    /// Block predicate evaluated after UPDATE.
    /// Prevents updating rows to values that don't satisfy the predicate.
    AfterUpdate,
    /// Block predicate evaluated before UPDATE.
    /// Prevents updating rows that currently don't satisfy the predicate.
    BeforeUpdate,
    /// Block predicate evaluated before DELETE.
    /// Prevents deleting rows that don't satisfy the predicate.
    BeforeDelete,
}
485
486impl MssqlBlockOperation {
487    /// Parse a block operation from a string.
488    #[allow(clippy::should_implement_trait)]
489    pub fn from_str(s: &str) -> Option<Self> {
490        match s.to_uppercase().replace([' ', '_'], "").as_str() {
491            "AFTERINSERT" => Some(Self::AfterInsert),
492            "AFTERUPDATE" => Some(Self::AfterUpdate),
493            "BEFOREUPDATE" => Some(Self::BeforeUpdate),
494            "BEFOREDELETE" => Some(Self::BeforeDelete),
495            _ => None,
496        }
497    }
498
499    /// Get the SQL clause for this block operation.
500    pub fn as_str(&self) -> &'static str {
501        match self {
502            Self::AfterInsert => "AFTER INSERT",
503            Self::AfterUpdate => "AFTER UPDATE",
504            Self::BeforeUpdate => "BEFORE UPDATE",
505            Self::BeforeDelete => "BEFORE DELETE",
506        }
507    }
508}
509
510impl std::fmt::Display for MssqlBlockOperation {
511    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
512        write!(f, "{}", self.as_str())
513    }
514}
515
/// SQL statements generated for an MSSQL security policy by
/// [`Policy::to_mssql_sql`].
///
/// Execute them in field order (see `all_statements`), each in its own batch
/// (`to_sql` separates them with `GO`).
#[derive(Debug, Clone, PartialEq)]
pub struct MssqlPolicyStatements {
    /// `CREATE SCHEMA` statement for the schema holding the predicate function
    /// and the security policy.
    ///
    /// The statement is unconditional, so it should only be run when that schema
    /// does not exist yet. NOTE(review): SQL Server rejects `CREATE SCHEMA` for
    /// an existing schema; confirm that callers check first.
    pub schema_sql: String,
    /// CREATE FUNCTION statement for the inline predicate function.
    pub function_sql: String,
    /// CREATE SECURITY POLICY statement binding the predicate function to the table.
    pub policy_sql: String,
}
526
527impl MssqlPolicyStatements {
528    /// Get all SQL statements in execution order.
529    pub fn all_statements(&self) -> Vec<&str> {
530        vec![&self.schema_sql, &self.function_sql, &self.policy_sql]
531    }
532
533    /// Get all SQL as a single string with separators.
534    pub fn to_sql(&self) -> String {
535        format!(
536            "{schema_sql};\nGO\n\n{function_sql};\nGO\n\n{policy_sql};",
537            schema_sql = self.schema_sql,
538            function_sql = self.function_sql,
539            policy_sql = self.policy_sql
540        )
541    }
542}
543
544#[cfg(test)]
545mod tests {
546    use super::*;
547
548    fn make_span() -> Span {
549        Span::new(0, 10)
550    }
551
552    fn make_ident(name: &str) -> Ident {
553        Ident::new(name, make_span())
554    }
555
556    // ==================== Policy Tests ====================
557
558    #[test]
559    fn test_policy_new() {
560        let policy = Policy::new(make_ident("read_own"), make_ident("User"), make_span());
561
562        assert_eq!(policy.name(), "read_own");
563        assert_eq!(policy.table(), "User");
564        assert_eq!(policy.policy_type, PolicyType::Permissive);
565        assert_eq!(policy.commands, vec![PolicyCommand::All]);
566        assert!(policy.roles.is_empty());
567        assert!(policy.using_expr.is_none());
568        assert!(policy.check_expr.is_none());
569        assert!(policy.documentation.is_none());
570    }
571
572    #[test]
573    fn test_policy_with_type() {
574        let policy = Policy::new(make_ident("strict"), make_ident("User"), make_span())
575            .with_type(PolicyType::Restrictive);
576
577        assert!(policy.is_restrictive());
578        assert!(!policy.is_permissive());
579    }
580
581    #[test]
582    fn test_policy_with_commands() {
583        let policy = Policy::new(make_ident("read"), make_ident("User"), make_span())
584            .with_commands(vec![PolicyCommand::Select]);
585
586        assert!(policy.applies_to(PolicyCommand::Select));
587        assert!(!policy.applies_to(PolicyCommand::Insert));
588        assert!(!policy.applies_to(PolicyCommand::Update));
589        assert!(!policy.applies_to(PolicyCommand::Delete));
590    }
591
592    #[test]
593    fn test_policy_with_multiple_commands() {
594        let policy = Policy::new(make_ident("read_update"), make_ident("User"), make_span())
595            .with_commands(vec![PolicyCommand::Select, PolicyCommand::Update]);
596
597        assert!(policy.applies_to(PolicyCommand::Select));
598        assert!(policy.applies_to(PolicyCommand::Update));
599        assert!(!policy.applies_to(PolicyCommand::Insert));
600        assert!(!policy.applies_to(PolicyCommand::Delete));
601    }
602
603    #[test]
604    fn test_policy_all_command_applies_to_all() {
605        let policy = Policy::new(make_ident("all"), make_ident("User"), make_span())
606            .with_commands(vec![PolicyCommand::All]);
607
608        assert!(policy.applies_to(PolicyCommand::Select));
609        assert!(policy.applies_to(PolicyCommand::Insert));
610        assert!(policy.applies_to(PolicyCommand::Update));
611        assert!(policy.applies_to(PolicyCommand::Delete));
612        assert!(policy.applies_to(PolicyCommand::All));
613    }
614
615    #[test]
616    fn test_policy_add_command() {
617        let mut policy =
618            Policy::new(make_ident("test"), make_ident("User"), make_span()).with_commands(vec![]);
619
620        policy.add_command(PolicyCommand::Select);
621        policy.add_command(PolicyCommand::Update);
622
623        assert_eq!(policy.commands.len(), 2);
624        assert!(policy.applies_to(PolicyCommand::Select));
625        assert!(policy.applies_to(PolicyCommand::Update));
626    }
627
628    #[test]
629    fn test_policy_with_roles() {
630        let policy = Policy::new(make_ident("auth"), make_ident("User"), make_span())
631            .with_roles(vec!["authenticated".into(), "admin".into()]);
632
633        assert_eq!(policy.roles.len(), 2);
634        let roles = policy.effective_roles();
635        assert!(roles.contains(&"authenticated"));
636        assert!(roles.contains(&"admin"));
637    }
638
639    #[test]
640    fn test_policy_add_role() {
641        let mut policy = Policy::new(make_ident("test"), make_ident("User"), make_span());
642
643        policy.add_role("user");
644        policy.add_role("moderator");
645
646        assert_eq!(policy.roles.len(), 2);
647    }
648
649    #[test]
650    fn test_policy_effective_roles_default() {
651        let policy = Policy::new(make_ident("public"), make_ident("User"), make_span());
652
653        let roles = policy.effective_roles();
654        assert_eq!(roles, vec!["PUBLIC"]);
655    }
656
657    #[test]
658    fn test_policy_with_using() {
659        let policy = Policy::new(make_ident("own"), make_ident("User"), make_span())
660            .with_using("user_id = current_user_id()");
661
662        assert_eq!(
663            policy.using_expr.as_deref(),
664            Some("user_id = current_user_id()")
665        );
666    }
667
668    #[test]
669    fn test_policy_with_check() {
670        let policy = Policy::new(make_ident("insert"), make_ident("User"), make_span())
671            .with_check("user_id = current_user_id()");
672
673        assert_eq!(
674            policy.check_expr.as_deref(),
675            Some("user_id = current_user_id()")
676        );
677    }
678
679    #[test]
680    fn test_policy_with_documentation() {
681        let policy =
682            Policy::new(make_ident("doc"), make_ident("User"), make_span()).with_documentation(
683                Documentation::new("Users can only see their own data", make_span()),
684            );
685
686        assert!(policy.documentation.is_some());
687        assert_eq!(
688            policy.documentation.unwrap().text,
689            "Users can only see their own data"
690        );
691    }
692
693    #[test]
694    fn test_policy_to_sql_simple() {
695        let policy = Policy::new(make_ident("read_own"), make_ident("User"), make_span())
696            .with_commands(vec![PolicyCommand::Select])
697            .with_using("id = current_user_id()");
698
699        let sql = policy.to_sql("users");
700        assert!(sql.contains("CREATE POLICY read_own ON users"));
701        assert!(sql.contains("FOR SELECT"));
702        assert!(sql.contains("TO PUBLIC"));
703        assert!(sql.contains("USING (id = current_user_id())"));
704    }
705
706    #[test]
707    fn test_policy_to_sql_with_roles() {
708        let policy = Policy::new(make_ident("auth_read"), make_ident("User"), make_span())
709            .with_commands(vec![PolicyCommand::Select])
710            .with_roles(vec!["authenticated".into()])
711            .with_using("true");
712
713        let sql = policy.to_sql("users");
714        assert!(sql.contains("TO authenticated"));
715    }
716
717    #[test]
718    fn test_policy_to_sql_restrictive() {
719        let policy = Policy::new(make_ident("restrict"), make_ident("User"), make_span())
720            .with_type(PolicyType::Restrictive)
721            .with_using("org_id = current_org_id()");
722
723        let sql = policy.to_sql("users");
724        assert!(sql.contains("AS RESTRICTIVE"));
725    }
726
727    #[test]
728    fn test_policy_to_sql_with_check() {
729        let policy = Policy::new(make_ident("insert_own"), make_ident("User"), make_span())
730            .with_commands(vec![PolicyCommand::Insert])
731            .with_check("id = current_user_id()");
732
733        let sql = policy.to_sql("users");
734        assert!(sql.contains("FOR INSERT"));
735        assert!(sql.contains("WITH CHECK (id = current_user_id())"));
736    }
737
738    #[test]
739    fn test_policy_to_sql_both_expressions() {
740        let policy = Policy::new(make_ident("update_own"), make_ident("User"), make_span())
741            .with_commands(vec![PolicyCommand::Update])
742            .with_using("id = current_user_id()")
743            .with_check("id = current_user_id()");
744
745        let sql = policy.to_sql("users");
746        assert!(sql.contains("USING (id = current_user_id())"));
747        assert!(sql.contains("WITH CHECK (id = current_user_id())"));
748    }
749
750    #[test]
751    fn test_policy_equality() {
752        let policy1 = Policy::new(make_ident("test"), make_ident("User"), make_span());
753        let policy2 = Policy::new(make_ident("test"), make_ident("User"), make_span());
754
755        assert_eq!(policy1, policy2);
756    }
757
758    #[test]
759    fn test_policy_clone() {
760        let policy = Policy::new(make_ident("original"), make_ident("User"), make_span())
761            .with_using("id = 1");
762
763        let cloned = policy.clone();
764        assert_eq!(cloned.name(), "original");
765        assert_eq!(cloned.using_expr, Some("id = 1".to_string()));
766    }
767
768    // ==================== PolicyType Tests ====================
769
770    #[test]
771    fn test_policy_type_from_str() {
772        assert_eq!(
773            PolicyType::from_str("PERMISSIVE"),
774            Some(PolicyType::Permissive)
775        );
776        assert_eq!(
777            PolicyType::from_str("permissive"),
778            Some(PolicyType::Permissive)
779        );
780        assert_eq!(
781            PolicyType::from_str("Permissive"),
782            Some(PolicyType::Permissive)
783        );
784        assert_eq!(
785            PolicyType::from_str("RESTRICTIVE"),
786            Some(PolicyType::Restrictive)
787        );
788        assert_eq!(
789            PolicyType::from_str("restrictive"),
790            Some(PolicyType::Restrictive)
791        );
792        assert_eq!(PolicyType::from_str("invalid"), None);
793    }
794
795    #[test]
796    fn test_policy_type_as_str() {
797        assert_eq!(PolicyType::Permissive.as_str(), "PERMISSIVE");
798        assert_eq!(PolicyType::Restrictive.as_str(), "RESTRICTIVE");
799    }
800
801    #[test]
802    fn test_policy_type_display() {
803        assert_eq!(format!("{}", PolicyType::Permissive), "PERMISSIVE");
804        assert_eq!(format!("{}", PolicyType::Restrictive), "RESTRICTIVE");
805    }
806
807    #[test]
808    fn test_policy_type_default() {
809        let policy_type: PolicyType = Default::default();
810        assert_eq!(policy_type, PolicyType::Permissive);
811    }
812
813    #[test]
814    fn test_policy_type_equality() {
815        assert_eq!(PolicyType::Permissive, PolicyType::Permissive);
816        assert_eq!(PolicyType::Restrictive, PolicyType::Restrictive);
817        assert_ne!(PolicyType::Permissive, PolicyType::Restrictive);
818    }
819
820    // ==================== PolicyCommand Tests ====================
821
822    #[test]
823    fn test_policy_command_from_str() {
824        assert_eq!(PolicyCommand::from_str("ALL"), Some(PolicyCommand::All));
825        assert_eq!(PolicyCommand::from_str("all"), Some(PolicyCommand::All));
826        assert_eq!(
827            PolicyCommand::from_str("SELECT"),
828            Some(PolicyCommand::Select)
829        );
830        assert_eq!(
831            PolicyCommand::from_str("select"),
832            Some(PolicyCommand::Select)
833        );
834        assert_eq!(
835            PolicyCommand::from_str("INSERT"),
836            Some(PolicyCommand::Insert)
837        );
838        assert_eq!(
839            PolicyCommand::from_str("UPDATE"),
840            Some(PolicyCommand::Update)
841        );
842        assert_eq!(
843            PolicyCommand::from_str("DELETE"),
844            Some(PolicyCommand::Delete)
845        );
846        assert_eq!(PolicyCommand::from_str("invalid"), None);
847    }
848
849    #[test]
850    fn test_policy_command_as_str() {
851        assert_eq!(PolicyCommand::All.as_str(), "ALL");
852        assert_eq!(PolicyCommand::Select.as_str(), "SELECT");
853        assert_eq!(PolicyCommand::Insert.as_str(), "INSERT");
854        assert_eq!(PolicyCommand::Update.as_str(), "UPDATE");
855        assert_eq!(PolicyCommand::Delete.as_str(), "DELETE");
856    }
857
858    #[test]
859    fn test_policy_command_display() {
860        assert_eq!(format!("{}", PolicyCommand::All), "ALL");
861        assert_eq!(format!("{}", PolicyCommand::Select), "SELECT");
862        assert_eq!(format!("{}", PolicyCommand::Insert), "INSERT");
863        assert_eq!(format!("{}", PolicyCommand::Update), "UPDATE");
864        assert_eq!(format!("{}", PolicyCommand::Delete), "DELETE");
865    }
866
867    #[test]
868    fn test_policy_command_requires_using() {
869        assert!(PolicyCommand::All.requires_using());
870        assert!(PolicyCommand::Select.requires_using());
871        assert!(PolicyCommand::Update.requires_using());
872        assert!(PolicyCommand::Delete.requires_using());
873        assert!(!PolicyCommand::Insert.requires_using());
874    }
875
876    #[test]
877    fn test_policy_command_requires_check() {
878        assert!(PolicyCommand::All.requires_check());
879        assert!(PolicyCommand::Insert.requires_check());
880        assert!(PolicyCommand::Update.requires_check());
881        assert!(!PolicyCommand::Select.requires_check());
882        assert!(!PolicyCommand::Delete.requires_check());
883    }
884
885    #[test]
886    fn test_policy_command_equality() {
887        assert_eq!(PolicyCommand::Select, PolicyCommand::Select);
888        assert_ne!(PolicyCommand::Select, PolicyCommand::Insert);
889    }
890
891    // ==================== Full Policy Scenario Tests ====================
892
893    #[test]
894    fn test_policy_rls_scenario_user_isolation() {
895        // Scenario: Users can only see and modify their own records
896        let policy = Policy::new(
897            make_ident("user_isolation"),
898            make_ident("User"),
899            make_span(),
900        )
901        .with_type(PolicyType::Permissive)
902        .with_commands(vec![PolicyCommand::All])
903        .with_roles(vec!["authenticated".into()])
904        .with_using("id = auth.uid()")
905        .with_check("id = auth.uid()");
906
907        assert!(policy.is_permissive());
908        assert!(policy.applies_to(PolicyCommand::Select));
909        assert!(policy.applies_to(PolicyCommand::Insert));
910        assert!(policy.applies_to(PolicyCommand::Update));
911        assert!(policy.applies_to(PolicyCommand::Delete));
912
913        let sql = policy.to_sql("users");
914        assert!(sql.contains("auth.uid()"));
915    }
916
917    #[test]
918    fn test_policy_rls_scenario_org_based() {
919        // Scenario: Users can only access records in their organization
920        let policy = Policy::new(
921            make_ident("org_access"),
922            make_ident("Document"),
923            make_span(),
924        )
925        .with_type(PolicyType::Restrictive)
926        .with_commands(vec![PolicyCommand::Select, PolicyCommand::Update])
927        .with_using("org_id = current_setting('app.current_org')::uuid");
928
929        assert!(policy.is_restrictive());
930        assert!(policy.applies_to(PolicyCommand::Select));
931        assert!(policy.applies_to(PolicyCommand::Update));
932        assert!(!policy.applies_to(PolicyCommand::Delete));
933
934        let sql = policy.to_sql("documents");
935        assert!(sql.contains("AS RESTRICTIVE"));
936        assert!(sql.contains("current_setting"));
937    }
938
939    #[test]
940    fn test_policy_rls_scenario_public_read() {
941        // Scenario: Anyone can read, only owner can modify
942        let read_policy = Policy::new(make_ident("public_read"), make_ident("Post"), make_span())
943            .with_commands(vec![PolicyCommand::Select])
944            .with_using("published = true OR author_id = current_user_id()");
945
946        let write_policy = Policy::new(make_ident("owner_write"), make_ident("Post"), make_span())
947            .with_commands(vec![PolicyCommand::Update, PolicyCommand::Delete])
948            .with_roles(vec!["authenticated".into()])
949            .with_using("author_id = current_user_id()");
950
951        assert_eq!(read_policy.effective_roles(), vec!["PUBLIC"]);
952        assert!(write_policy.effective_roles().contains(&"authenticated"));
953    }
954
955    // ==================== MSSQL Block Operation Tests ====================
956
957    #[test]
958    fn test_mssql_block_operation_from_str() {
959        assert_eq!(
960            MssqlBlockOperation::from_str("AFTER INSERT"),
961            Some(MssqlBlockOperation::AfterInsert)
962        );
963        assert_eq!(
964            MssqlBlockOperation::from_str("after_insert"),
965            Some(MssqlBlockOperation::AfterInsert)
966        );
967        assert_eq!(
968            MssqlBlockOperation::from_str("AFTERINSERT"),
969            Some(MssqlBlockOperation::AfterInsert)
970        );
971        assert_eq!(
972            MssqlBlockOperation::from_str("AFTER UPDATE"),
973            Some(MssqlBlockOperation::AfterUpdate)
974        );
975        assert_eq!(
976            MssqlBlockOperation::from_str("BEFORE UPDATE"),
977            Some(MssqlBlockOperation::BeforeUpdate)
978        );
979        assert_eq!(
980            MssqlBlockOperation::from_str("BEFORE DELETE"),
981            Some(MssqlBlockOperation::BeforeDelete)
982        );
983        assert_eq!(MssqlBlockOperation::from_str("invalid"), None);
984    }
985
986    #[test]
987    fn test_mssql_block_operation_as_str() {
988        assert_eq!(MssqlBlockOperation::AfterInsert.as_str(), "AFTER INSERT");
989        assert_eq!(MssqlBlockOperation::AfterUpdate.as_str(), "AFTER UPDATE");
990        assert_eq!(MssqlBlockOperation::BeforeUpdate.as_str(), "BEFORE UPDATE");
991        assert_eq!(MssqlBlockOperation::BeforeDelete.as_str(), "BEFORE DELETE");
992    }
993
994    #[test]
995    fn test_mssql_block_operation_display() {
996        assert_eq!(
997            format!("{}", MssqlBlockOperation::AfterInsert),
998            "AFTER INSERT"
999        );
1000        assert_eq!(
1001            format!("{}", MssqlBlockOperation::BeforeDelete),
1002            "BEFORE DELETE"
1003        );
1004    }
1005
1006    // ==================== MSSQL Policy Tests ====================
1007
1008    #[test]
1009    fn test_policy_mssql_schema_default() {
1010        let policy = Policy::new(make_ident("test"), make_ident("User"), make_span());
1011        assert_eq!(policy.mssql_schema(), "Security");
1012    }
1013
1014    #[test]
1015    fn test_policy_with_mssql_schema() {
1016        let policy = Policy::new(make_ident("test"), make_ident("User"), make_span())
1017            .with_mssql_schema("RLS");
1018        assert_eq!(policy.mssql_schema(), "RLS");
1019    }
1020
1021    #[test]
1022    fn test_policy_mssql_predicate_function_name() {
1023        let policy = Policy::new(make_ident("user_filter"), make_ident("User"), make_span());
1024        assert_eq!(
1025            policy.mssql_predicate_function_name(),
1026            "fn_user_filter_predicate"
1027        );
1028    }
1029
1030    #[test]
1031    fn test_policy_with_mssql_block_operations() {
1032        let policy = Policy::new(make_ident("test"), make_ident("User"), make_span())
1033            .with_mssql_block_operations(vec![
1034                MssqlBlockOperation::AfterInsert,
1035                MssqlBlockOperation::AfterUpdate,
1036            ]);
1037
1038        assert_eq!(policy.mssql_block_operations.len(), 2);
1039    }
1040
1041    #[test]
1042    fn test_policy_add_mssql_block_operation() {
1043        let mut policy = Policy::new(make_ident("test"), make_ident("User"), make_span());
1044
1045        policy.add_mssql_block_operation(MssqlBlockOperation::AfterInsert);
1046        policy.add_mssql_block_operation(MssqlBlockOperation::BeforeDelete);
1047
1048        assert_eq!(policy.mssql_block_operations.len(), 2);
1049    }
1050
1051    #[test]
1052    fn test_policy_to_mssql_sql_simple() {
1053        let policy = Policy::new(make_ident("user_filter"), make_ident("User"), make_span())
1054            .with_commands(vec![PolicyCommand::Select])
1055            .with_using("UserId = @UserId");
1056
1057        let mssql = policy.to_mssql_sql("dbo.Users", "UserId");
1058
1059        // Check schema creation
1060        assert!(mssql.schema_sql.contains("CREATE SCHEMA Security"));
1061
1062        // Check function creation
1063        assert!(
1064            mssql
1065                .function_sql
1066                .contains("CREATE FUNCTION Security.fn_user_filter_predicate")
1067        );
1068        assert!(mssql.function_sql.contains("@UserId AS INT"));
1069        assert!(mssql.function_sql.contains("WITH SCHEMABINDING"));
1070        assert!(mssql.function_sql.contains("RETURNS TABLE"));
1071
1072        // Check policy creation
1073        assert!(
1074            mssql
1075                .policy_sql
1076                .contains("CREATE SECURITY POLICY Security.user_filter")
1077        );
1078        assert!(mssql.policy_sql.contains("FILTER PREDICATE"));
1079        assert!(mssql.policy_sql.contains("ON dbo.Users"));
1080        assert!(mssql.policy_sql.contains("WITH (STATE = ON)"));
1081    }
1082
1083    #[test]
1084    fn test_policy_to_mssql_sql_with_check() {
1085        let policy = Policy::new(make_ident("user_insert"), make_ident("User"), make_span())
1086            .with_commands(vec![PolicyCommand::Insert])
1087            .with_check("UserId = @UserId");
1088
1089        let mssql = policy.to_mssql_sql("dbo.Users", "UserId");
1090
1091        assert!(mssql.policy_sql.contains("BLOCK PREDICATE"));
1092        assert!(mssql.policy_sql.contains("AFTER INSERT"));
1093    }
1094
1095    #[test]
1096    fn test_policy_to_mssql_sql_with_both() {
1097        let policy = Policy::new(make_ident("user_all"), make_ident("User"), make_span())
1098            .with_commands(vec![PolicyCommand::All])
1099            .with_using("UserId = @UserId")
1100            .with_check("UserId = @UserId");
1101
1102        let mssql = policy.to_mssql_sql("dbo.Users", "UserId");
1103
1104        assert!(mssql.policy_sql.contains("FILTER PREDICATE"));
1105        assert!(mssql.policy_sql.contains("BLOCK PREDICATE"));
1106        assert!(mssql.policy_sql.contains("AFTER INSERT"));
1107        assert!(mssql.policy_sql.contains("AFTER UPDATE"));
1108        assert!(mssql.policy_sql.contains("BEFORE UPDATE"));
1109        assert!(mssql.policy_sql.contains("BEFORE DELETE"));
1110    }
1111
1112    #[test]
1113    fn test_policy_to_mssql_sql_custom_schema() {
1114        let policy = Policy::new(make_ident("test"), make_ident("User"), make_span())
1115            .with_mssql_schema("RLS")
1116            .with_using("UserId = @UserId");
1117
1118        let mssql = policy.to_mssql_sql("dbo.Users", "UserId");
1119
1120        assert!(mssql.schema_sql.contains("CREATE SCHEMA RLS"));
1121        assert!(mssql.function_sql.contains("RLS.fn_test_predicate"));
1122        assert!(mssql.policy_sql.contains("RLS.test"));
1123    }
1124
1125    #[test]
1126    fn test_policy_to_mssql_sql_translates_postgres_functions() {
1127        let policy = Policy::new(make_ident("test"), make_ident("User"), make_span())
1128            .with_using("id = current_user_id()");
1129
1130        let mssql = policy.to_mssql_sql("dbo.Users", "UserId");
1131
1132        assert!(mssql.function_sql.contains("SESSION_CONTEXT(N'UserId')"));
1133        assert!(!mssql.function_sql.contains("current_user_id"));
1134    }
1135
1136    #[test]
1137    fn test_policy_to_mssql_sql_translates_auth_uid() {
1138        let policy = Policy::new(make_ident("test"), make_ident("User"), make_span())
1139            .with_using("id = auth.uid()");
1140
1141        let mssql = policy.to_mssql_sql("dbo.Users", "UserId");
1142
1143        assert!(mssql.function_sql.contains("SESSION_CONTEXT(N'UserId')"));
1144        assert!(!mssql.function_sql.contains("auth.uid"));
1145    }
1146
1147    #[test]
1148    fn test_mssql_policy_statements_all_statements() {
1149        let policy = Policy::new(make_ident("test"), make_ident("User"), make_span())
1150            .with_using("UserId = @UserId");
1151
1152        let mssql = policy.to_mssql_sql("dbo.Users", "UserId");
1153        let statements = mssql.all_statements();
1154
1155        assert_eq!(statements.len(), 3);
1156    }
1157
1158    #[test]
1159    fn test_mssql_policy_statements_to_sql() {
1160        let policy = Policy::new(make_ident("test"), make_ident("User"), make_span())
1161            .with_using("UserId = @UserId");
1162
1163        let mssql = policy.to_mssql_sql("dbo.Users", "UserId");
1164        let full_sql = mssql.to_sql();
1165
1166        assert!(full_sql.contains("GO"));
1167        assert!(full_sql.contains("CREATE SCHEMA"));
1168        assert!(full_sql.contains("CREATE FUNCTION"));
1169        assert!(full_sql.contains("CREATE SECURITY POLICY"));
1170    }
1171
1172    #[test]
1173    fn test_policy_default_mssql_block_operations() {
1174        // Test INSERT only
1175        let insert_policy = Policy::new(make_ident("test"), make_ident("User"), make_span())
1176            .with_commands(vec![PolicyCommand::Insert])
1177            .with_check("true");
1178
1179        let mssql = insert_policy.to_mssql_sql("dbo.Users", "UserId");
1180        assert!(mssql.policy_sql.contains("AFTER INSERT"));
1181        assert!(!mssql.policy_sql.contains("BEFORE DELETE"));
1182
1183        // Test UPDATE only
1184        let update_policy = Policy::new(make_ident("test"), make_ident("User"), make_span())
1185            .with_commands(vec![PolicyCommand::Update])
1186            .with_check("true");
1187
1188        let mssql = update_policy.to_mssql_sql("dbo.Users", "UserId");
1189        assert!(mssql.policy_sql.contains("AFTER UPDATE"));
1190        assert!(mssql.policy_sql.contains("BEFORE UPDATE"));
1191        assert!(!mssql.policy_sql.contains("AFTER INSERT"));
1192
1193        // Test DELETE only
1194        let delete_policy = Policy::new(make_ident("test"), make_ident("User"), make_span())
1195            .with_commands(vec![PolicyCommand::Delete])
1196            .with_check("true");
1197
1198        let mssql = delete_policy.to_mssql_sql("dbo.Users", "UserId");
1199        assert!(mssql.policy_sql.contains("BEFORE DELETE"));
1200        assert!(!mssql.policy_sql.contains("AFTER INSERT"));
1201    }
1202
1203    #[test]
1204    fn test_policy_mssql_custom_block_operations() {
1205        let policy = Policy::new(make_ident("test"), make_ident("User"), make_span())
1206            .with_commands(vec![PolicyCommand::All])
1207            .with_check("true")
1208            .with_mssql_block_operations(vec![MssqlBlockOperation::AfterInsert]);
1209
1210        let mssql = policy.to_mssql_sql("dbo.Users", "UserId");
1211
1212        // Should only have the custom operation, not all defaults
1213        assert!(mssql.policy_sql.contains("AFTER INSERT"));
1214        assert!(!mssql.policy_sql.contains("BEFORE DELETE"));
1215        assert!(!mssql.policy_sql.contains("AFTER UPDATE"));
1216    }
1217
1218    // ==================== MSSQL Real-World Scenario Tests ====================
1219
1220    #[test]
1221    fn test_mssql_rls_scenario_user_isolation() {
1222        let policy = Policy::new(
1223            make_ident("user_isolation"),
1224            make_ident("User"),
1225            make_span(),
1226        )
1227        .with_mssql_schema("Security")
1228        .with_commands(vec![PolicyCommand::All])
1229        .with_using("UserId = CAST(SESSION_CONTEXT(N'UserId') AS INT)")
1230        .with_check("UserId = CAST(SESSION_CONTEXT(N'UserId') AS INT)");
1231
1232        let mssql = policy.to_mssql_sql("dbo.Users", "UserId");
1233
1234        // Verify complete MSSQL RLS setup
1235        assert!(mssql.schema_sql.contains("CREATE SCHEMA Security"));
1236        assert!(mssql.function_sql.contains("fn_user_isolation_predicate"));
1237        assert!(mssql.policy_sql.contains("user_isolation"));
1238        assert!(mssql.policy_sql.contains("WITH (STATE = ON)"));
1239    }
1240
1241    #[test]
1242    fn test_mssql_rls_scenario_multi_tenant() {
1243        let policy = Policy::new(
1244            make_ident("tenant_isolation"),
1245            make_ident("Order"),
1246            make_span(),
1247        )
1248        .with_mssql_schema("MultiTenant")
1249        .with_using("TenantId = CAST(SESSION_CONTEXT(N'TenantId') AS INT)")
1250        .with_check("TenantId = CAST(SESSION_CONTEXT(N'TenantId') AS INT)");
1251
1252        let mssql = policy.to_mssql_sql("dbo.Orders", "TenantId");
1253
1254        assert!(mssql.schema_sql.contains("MultiTenant"));
1255        assert!(mssql.function_sql.contains("@TenantId AS INT"));
1256        assert!(mssql.policy_sql.contains("dbo.Orders"));
1257    }
1258}
1259