Skip to main content

prax_schema/ast/
policy.rs

1//! Row-Level Security (RLS) policy definitions for the Prax schema AST.
2//!
3//! Policies enable fine-grained access control at the row level.
4//! They are evaluated for each row that a query accesses and determine whether
5//! the row should be visible or modifiable based on the policy expression.
6//!
7//! ## Supported Databases
8//!
9//! - **PostgreSQL**: Native RLS with CREATE POLICY
10//! - **SQL Server (MSSQL)**: Security Policies with predicate functions
11//!
12//! ## PostgreSQL RLS
13//!
14//! PostgreSQL uses `CREATE POLICY` statements with USING (filter) and WITH CHECK
15//! (block) expressions evaluated inline.
16//!
17//! ## SQL Server RLS
18//!
19//! SQL Server requires:
20//! 1. A schema-bound inline table-valued function (predicate function)
21//! 2. A security policy binding the function to the table
22//!
23//! The predicate function returns 1 for rows that should be accessible.
24
25use serde::{Deserialize, Serialize};
26use smol_str::SmolStr;
27
28use super::{Documentation, Ident, Span};
29
30/// A Row-Level Security (RLS) policy definition.
31///
32/// Policies provide fine-grained access control at the row level.
33/// They are applied to tables and evaluated for each row operation.
34///
35/// # Example Schema Syntax
36///
37/// ```text
38/// policy UserReadOwnData on User {
39///     for     SELECT
40///     to      authenticated
41///     using   "id = current_user_id()"
42/// }
43///
44/// policy UserModifyOwnData on User {
45///     for     [INSERT, UPDATE, DELETE]
46///     to      authenticated
47///     using   "id = current_user_id()"
48///     check   "id = current_user_id()"
49/// }
50/// ```
51///
52/// # Database Support
53///
54/// - PostgreSQL: Full support via CREATE POLICY
55/// - SQL Server: Supported via Security Policies with predicate functions
56#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
57pub struct Policy {
58    /// Policy name (must be unique per table).
59    pub name: Ident,
60    /// The model/table this policy applies to.
61    pub table: Ident,
62    /// Policy type: PERMISSIVE (default) or RESTRICTIVE.
63    pub policy_type: PolicyType,
64    /// Commands this policy applies to (SELECT, INSERT, UPDATE, DELETE, or ALL).
65    pub commands: Vec<PolicyCommand>,
66    /// Roles this policy applies to (default: PUBLIC).
67    pub roles: Vec<SmolStr>,
68    /// USING expression - evaluated for existing rows (SELECT, UPDATE, DELETE).
69    /// Should return boolean. Row is visible if expression returns true.
70    /// In MSSQL, this becomes the FILTER PREDICATE.
71    pub using_expr: Option<String>,
72    /// WITH CHECK expression - evaluated for new rows (INSERT, UPDATE).
73    /// Should return boolean. Row can be inserted/updated if expression returns true.
74    /// In MSSQL, this becomes BLOCK PREDICATE(s).
75    pub check_expr: Option<String>,
76    /// MSSQL-specific: Schema for the predicate function (default: "Security").
77    pub mssql_schema: Option<SmolStr>,
78    /// MSSQL-specific: Block operations to apply (default: all applicable).
79    pub mssql_block_operations: Vec<MssqlBlockOperation>,
80    /// Documentation comment.
81    pub documentation: Option<Documentation>,
82    /// Source location.
83    pub span: Span,
84    /// Source file this policy was parsed from (None for single-file path).
85    #[serde(default, skip_serializing_if = "Option::is_none")]
86    pub source_id: Option<crate::loader::SourceId>,
87}
88
89impl Policy {
90    /// Create a new policy.
91    pub fn new(name: Ident, table: Ident, span: Span) -> Self {
92        Self {
93            name,
94            table,
95            policy_type: PolicyType::Permissive,
96            commands: vec![PolicyCommand::All],
97            roles: vec![],
98            using_expr: None,
99            check_expr: None,
100            mssql_schema: None,
101            mssql_block_operations: vec![],
102            documentation: None,
103            source_id: None,
104            span,
105        }
106    }
107
108    /// Get the policy name as a string.
109    pub fn name(&self) -> &str {
110        self.name.as_str()
111    }
112
113    /// Get the table name as a string.
114    pub fn table(&self) -> &str {
115        self.table.as_str()
116    }
117
118    /// Set the policy type.
119    pub fn with_type(mut self, policy_type: PolicyType) -> Self {
120        self.policy_type = policy_type;
121        self
122    }
123
124    /// Set the commands this policy applies to.
125    pub fn with_commands(mut self, commands: Vec<PolicyCommand>) -> Self {
126        self.commands = commands;
127        self
128    }
129
130    /// Add a command this policy applies to.
131    pub fn add_command(&mut self, command: PolicyCommand) {
132        self.commands.push(command);
133    }
134
135    /// Set the roles this policy applies to.
136    pub fn with_roles(mut self, roles: Vec<SmolStr>) -> Self {
137        self.roles = roles;
138        self
139    }
140
141    /// Add a role this policy applies to.
142    pub fn add_role(&mut self, role: impl Into<SmolStr>) {
143        self.roles.push(role.into());
144    }
145
146    /// Set the USING expression.
147    pub fn with_using(mut self, expr: impl Into<String>) -> Self {
148        self.using_expr = Some(expr.into());
149        self
150    }
151
152    /// Set the WITH CHECK expression.
153    pub fn with_check(mut self, expr: impl Into<String>) -> Self {
154        self.check_expr = Some(expr.into());
155        self
156    }
157
158    /// Set documentation.
159    pub fn with_documentation(mut self, doc: Documentation) -> Self {
160        self.documentation = Some(doc);
161        self
162    }
163
164    /// Set the MSSQL schema for the predicate function.
165    pub fn with_mssql_schema(mut self, schema: impl Into<SmolStr>) -> Self {
166        self.mssql_schema = Some(schema.into());
167        self
168    }
169
170    /// Set the MSSQL block operations.
171    pub fn with_mssql_block_operations(mut self, operations: Vec<MssqlBlockOperation>) -> Self {
172        self.mssql_block_operations = operations;
173        self
174    }
175
176    /// Add an MSSQL block operation.
177    pub fn add_mssql_block_operation(&mut self, operation: MssqlBlockOperation) {
178        self.mssql_block_operations.push(operation);
179    }
180
181    /// Check if this policy applies to a specific command.
182    pub fn applies_to(&self, command: PolicyCommand) -> bool {
183        self.commands.contains(&PolicyCommand::All) || self.commands.contains(&command)
184    }
185
186    /// Check if this policy is restrictive.
187    pub fn is_restrictive(&self) -> bool {
188        self.policy_type == PolicyType::Restrictive
189    }
190
191    /// Check if this policy is permissive.
192    pub fn is_permissive(&self) -> bool {
193        self.policy_type == PolicyType::Permissive
194    }
195
196    /// Get the effective roles (PUBLIC if none specified).
197    pub fn effective_roles(&self) -> Vec<&str> {
198        if self.roles.is_empty() {
199            vec!["PUBLIC"]
200        } else {
201            self.roles.iter().map(|r| r.as_str()).collect()
202        }
203    }
204
205    /// Get the MSSQL schema (default: "Security").
206    pub fn mssql_schema(&self) -> &str {
207        self.mssql_schema
208            .as_ref()
209            .map(|s| s.as_str())
210            .unwrap_or("Security")
211    }
212
213    /// Get the predicate function name for MSSQL.
214    pub fn mssql_predicate_function_name(&self) -> String {
215        format!("fn_{}_predicate", self.name())
216    }
217
218    /// Generate the PostgreSQL CREATE POLICY statement.
219    pub fn to_sql(&self, table_name: &str) -> String {
220        self.to_postgres_sql(table_name)
221    }
222
223    /// Generate the PostgreSQL CREATE POLICY statement.
224    pub fn to_postgres_sql(&self, table_name: &str) -> String {
225        let mut sql = format!("CREATE POLICY {} ON {}", self.name(), table_name);
226
227        // AS PERMISSIVE/RESTRICTIVE
228        match self.policy_type {
229            PolicyType::Permissive => {} // Default, no need to specify
230            PolicyType::Restrictive => sql.push_str(" AS RESTRICTIVE"),
231        }
232
233        // FOR command
234        if !self.commands.is_empty() && !self.commands.contains(&PolicyCommand::All) {
235            let cmds: Vec<&str> = self.commands.iter().map(|c| c.as_str()).collect();
236            // PostgreSQL only allows a single command, use first one
237            sql.push_str(&format!(" FOR {}", cmds[0]));
238        }
239
240        // TO roles
241        let roles = self.effective_roles();
242        sql.push_str(&format!(" TO {}", roles.join(", ")));
243
244        // USING expression
245        if let Some(ref using) = self.using_expr {
246            sql.push_str(&format!(" USING ({})", using));
247        }
248
249        // WITH CHECK expression
250        if let Some(ref check) = self.check_expr {
251            sql.push_str(&format!(" WITH CHECK ({})", check));
252        }
253
254        sql
255    }
256
257    /// Generate SQL Server (MSSQL) security policy statements.
258    ///
259    /// Returns a tuple of:
260    /// 1. CREATE FUNCTION statement for the predicate function
261    /// 2. CREATE SECURITY POLICY statement
262    ///
263    /// # Arguments
264    ///
265    /// * `table_name` - The fully qualified table name (e.g., "dbo.Users")
266    /// * `predicate_column` - The column name to use in the predicate (e.g., "UserId")
267    ///
268    /// # Example
269    ///
270    /// ```ignore
271    /// let (func_sql, policy_sql) = policy.to_mssql_sql("dbo.Users", "UserId");
272    /// ```
273    pub fn to_mssql_sql(&self, table_name: &str, predicate_column: &str) -> MssqlPolicyStatements {
274        let schema = self.mssql_schema();
275        let func_name = self.mssql_predicate_function_name();
276
277        // Generate the predicate function
278        let filter_expr = self
279            .using_expr
280            .as_deref()
281            .unwrap_or("1 = 1")
282            .replace(
283                "current_user_id()",
284                "CAST(SESSION_CONTEXT(N'UserId') AS INT)",
285            )
286            .replace("auth.uid()", "CAST(SESSION_CONTEXT(N'UserId') AS INT)")
287            .replace(
288                "current_setting('app.current_org')",
289                "SESSION_CONTEXT(N'OrgId')",
290            );
291
292        let function_sql = format!(
293            r#"CREATE FUNCTION {schema}.{func_name}(@{predicate_column} AS INT)
294    RETURNS TABLE
295WITH SCHEMABINDING
296AS
297    RETURN SELECT 1 AS fn_securitypredicate_result
298    WHERE {filter_expr}"#,
299            schema = schema,
300            func_name = func_name,
301            predicate_column = predicate_column,
302            filter_expr = filter_expr
303        );
304
305        // Generate the security policy
306        let mut policy_sql = format!(
307            "CREATE SECURITY POLICY {schema}.{policy_name}\n",
308            schema = schema,
309            policy_name = self.name()
310        );
311
312        // Add FILTER PREDICATE if we have a using expression
313        if self.using_expr.is_some() {
314            policy_sql.push_str(&format!(
315                "ADD FILTER PREDICATE {schema}.{func_name}({predicate_column}) ON {table_name}",
316                schema = schema,
317                func_name = func_name,
318                predicate_column = predicate_column,
319                table_name = table_name
320            ));
321        }
322
323        // Add BLOCK PREDICATE(s) if we have a check expression
324        if self.check_expr.is_some() {
325            let block_ops = if self.mssql_block_operations.is_empty() {
326                // Default block operations based on commands
327                self.default_mssql_block_operations()
328            } else {
329                self.mssql_block_operations.clone()
330            };
331
332            for (i, op) in block_ops.iter().enumerate() {
333                if i > 0 || self.using_expr.is_some() {
334                    policy_sql.push_str(",\n");
335                }
336                policy_sql.push_str(&format!(
337                    "ADD BLOCK PREDICATE {schema}.{func_name}({predicate_column}) ON {table_name} {op}",
338                    schema = schema,
339                    func_name = func_name,
340                    predicate_column = predicate_column,
341                    table_name = table_name,
342                    op = op.as_str()
343                ));
344            }
345        }
346
347        policy_sql.push_str("\nWITH (STATE = ON)");
348
349        MssqlPolicyStatements {
350            schema_sql: format!("CREATE SCHEMA {schema}"),
351            function_sql,
352            policy_sql,
353        }
354    }
355
356    /// Get default MSSQL block operations based on the policy commands.
357    fn default_mssql_block_operations(&self) -> Vec<MssqlBlockOperation> {
358        let mut ops = vec![];
359
360        if self.applies_to(PolicyCommand::Insert) {
361            ops.push(MssqlBlockOperation::AfterInsert);
362        }
363        if self.applies_to(PolicyCommand::Update) {
364            ops.push(MssqlBlockOperation::AfterUpdate);
365            ops.push(MssqlBlockOperation::BeforeUpdate);
366        }
367        if self.applies_to(PolicyCommand::Delete) {
368            ops.push(MssqlBlockOperation::BeforeDelete);
369        }
370
371        ops
372    }
373}
374
375/// PostgreSQL policy type.
376#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
377pub enum PolicyType {
378    /// Permissive policies are combined with OR.
379    /// At least one permissive policy must allow access.
380    #[default]
381    Permissive,
382    /// Restrictive policies are combined with AND.
383    /// All restrictive policies must allow access.
384    Restrictive,
385}
386
387impl PolicyType {
388    /// Parse a policy type from a string.
389    #[allow(clippy::should_implement_trait)]
390    pub fn from_str(s: &str) -> Option<Self> {
391        match s.to_uppercase().as_str() {
392            "PERMISSIVE" => Some(Self::Permissive),
393            "RESTRICTIVE" => Some(Self::Restrictive),
394            _ => None,
395        }
396    }
397
398    /// Get the SQL keyword for this policy type.
399    pub fn as_str(&self) -> &'static str {
400        match self {
401            Self::Permissive => "PERMISSIVE",
402            Self::Restrictive => "RESTRICTIVE",
403        }
404    }
405}
406
407impl std::fmt::Display for PolicyType {
408    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
409        write!(f, "{}", self.as_str())
410    }
411}
412
413/// PostgreSQL policy command type.
414#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
415pub enum PolicyCommand {
416    /// Policy applies to all commands (SELECT, INSERT, UPDATE, DELETE).
417    All,
418    /// Policy applies to SELECT queries.
419    Select,
420    /// Policy applies to INSERT statements.
421    Insert,
422    /// Policy applies to UPDATE statements.
423    Update,
424    /// Policy applies to DELETE statements.
425    Delete,
426}
427
428impl PolicyCommand {
429    /// Parse a policy command from a string.
430    #[allow(clippy::should_implement_trait)]
431    pub fn from_str(s: &str) -> Option<Self> {
432        match s.to_uppercase().as_str() {
433            "ALL" => Some(Self::All),
434            "SELECT" => Some(Self::Select),
435            "INSERT" => Some(Self::Insert),
436            "UPDATE" => Some(Self::Update),
437            "DELETE" => Some(Self::Delete),
438            _ => None,
439        }
440    }
441
442    /// Get the SQL keyword for this command.
443    pub fn as_str(&self) -> &'static str {
444        match self {
445            Self::All => "ALL",
446            Self::Select => "SELECT",
447            Self::Insert => "INSERT",
448            Self::Update => "UPDATE",
449            Self::Delete => "DELETE",
450        }
451    }
452
453    /// Check if this command requires a USING expression.
454    pub fn requires_using(&self) -> bool {
455        matches!(self, Self::All | Self::Select | Self::Update | Self::Delete)
456    }
457
458    /// Check if this command requires a WITH CHECK expression.
459    pub fn requires_check(&self) -> bool {
460        matches!(self, Self::All | Self::Insert | Self::Update)
461    }
462}
463
464impl std::fmt::Display for PolicyCommand {
465    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
466        write!(f, "{}", self.as_str())
467    }
468}
469
470/// MSSQL-specific block operation types.
471///
472/// SQL Server's BLOCK PREDICATE can be applied at different points
473/// in the data modification lifecycle.
474#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
475pub enum MssqlBlockOperation {
476    /// Block predicate evaluated after INSERT.
477    /// Prevents inserting rows that don't satisfy the predicate.
478    AfterInsert,
479    /// Block predicate evaluated after UPDATE.
480    /// Prevents updating rows to values that don't satisfy the predicate.
481    AfterUpdate,
482    /// Block predicate evaluated before UPDATE.
483    /// Prevents updating rows that currently don't satisfy the predicate.
484    BeforeUpdate,
485    /// Block predicate evaluated before DELETE.
486    /// Prevents deleting rows that don't satisfy the predicate.
487    BeforeDelete,
488}
489
490impl MssqlBlockOperation {
491    /// Parse a block operation from a string.
492    #[allow(clippy::should_implement_trait)]
493    pub fn from_str(s: &str) -> Option<Self> {
494        match s.to_uppercase().replace([' ', '_'], "").as_str() {
495            "AFTERINSERT" => Some(Self::AfterInsert),
496            "AFTERUPDATE" => Some(Self::AfterUpdate),
497            "BEFOREUPDATE" => Some(Self::BeforeUpdate),
498            "BEFOREDELETE" => Some(Self::BeforeDelete),
499            _ => None,
500        }
501    }
502
503    /// Get the SQL clause for this block operation.
504    pub fn as_str(&self) -> &'static str {
505        match self {
506            Self::AfterInsert => "AFTER INSERT",
507            Self::AfterUpdate => "AFTER UPDATE",
508            Self::BeforeUpdate => "BEFORE UPDATE",
509            Self::BeforeDelete => "BEFORE DELETE",
510        }
511    }
512}
513
514impl std::fmt::Display for MssqlBlockOperation {
515    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
516        write!(f, "{}", self.as_str())
517    }
518}
519
/// The set of SQL statements that together install one MSSQL security policy.
#[derive(Debug, Clone, PartialEq)]
pub struct MssqlPolicyStatements {
    /// `CREATE SCHEMA` statement for the schema that holds the predicate
    /// function (only needed if that schema doesn't exist yet).
    pub schema_sql: String,
    /// `CREATE FUNCTION` statement defining the predicate function.
    pub function_sql: String,
    /// `CREATE SECURITY POLICY` statement binding the function to the table.
    pub policy_sql: String,
}
530
531impl MssqlPolicyStatements {
532    /// Get all SQL statements in execution order.
533    pub fn all_statements(&self) -> Vec<&str> {
534        vec![&self.schema_sql, &self.function_sql, &self.policy_sql]
535    }
536
537    /// Get all SQL as a single string with separators.
538    pub fn to_sql(&self) -> String {
539        format!(
540            "{schema_sql};\nGO\n\n{function_sql};\nGO\n\n{policy_sql};",
541            schema_sql = self.schema_sql,
542            function_sql = self.function_sql,
543            policy_sql = self.policy_sql
544        )
545    }
546}
547
548#[cfg(test)]
549mod tests {
550    use super::*;
551
552    fn make_span() -> Span {
553        Span::new(0, 10)
554    }
555
556    fn make_ident(name: &str) -> Ident {
557        Ident::new(name, make_span())
558    }
559
560    // ==================== Policy Tests ====================
561
562    #[test]
563    fn test_policy_new() {
564        let policy = Policy::new(make_ident("read_own"), make_ident("User"), make_span());
565
566        assert_eq!(policy.name(), "read_own");
567        assert_eq!(policy.table(), "User");
568        assert_eq!(policy.policy_type, PolicyType::Permissive);
569        assert_eq!(policy.commands, vec![PolicyCommand::All]);
570        assert!(policy.roles.is_empty());
571        assert!(policy.using_expr.is_none());
572        assert!(policy.check_expr.is_none());
573        assert!(policy.documentation.is_none());
574    }
575
576    #[test]
577    fn test_policy_with_type() {
578        let policy = Policy::new(make_ident("strict"), make_ident("User"), make_span())
579            .with_type(PolicyType::Restrictive);
580
581        assert!(policy.is_restrictive());
582        assert!(!policy.is_permissive());
583    }
584
585    #[test]
586    fn test_policy_with_commands() {
587        let policy = Policy::new(make_ident("read"), make_ident("User"), make_span())
588            .with_commands(vec![PolicyCommand::Select]);
589
590        assert!(policy.applies_to(PolicyCommand::Select));
591        assert!(!policy.applies_to(PolicyCommand::Insert));
592        assert!(!policy.applies_to(PolicyCommand::Update));
593        assert!(!policy.applies_to(PolicyCommand::Delete));
594    }
595
596    #[test]
597    fn test_policy_with_multiple_commands() {
598        let policy = Policy::new(make_ident("read_update"), make_ident("User"), make_span())
599            .with_commands(vec![PolicyCommand::Select, PolicyCommand::Update]);
600
601        assert!(policy.applies_to(PolicyCommand::Select));
602        assert!(policy.applies_to(PolicyCommand::Update));
603        assert!(!policy.applies_to(PolicyCommand::Insert));
604        assert!(!policy.applies_to(PolicyCommand::Delete));
605    }
606
607    #[test]
608    fn test_policy_all_command_applies_to_all() {
609        let policy = Policy::new(make_ident("all"), make_ident("User"), make_span())
610            .with_commands(vec![PolicyCommand::All]);
611
612        assert!(policy.applies_to(PolicyCommand::Select));
613        assert!(policy.applies_to(PolicyCommand::Insert));
614        assert!(policy.applies_to(PolicyCommand::Update));
615        assert!(policy.applies_to(PolicyCommand::Delete));
616        assert!(policy.applies_to(PolicyCommand::All));
617    }
618
619    #[test]
620    fn test_policy_add_command() {
621        let mut policy =
622            Policy::new(make_ident("test"), make_ident("User"), make_span()).with_commands(vec![]);
623
624        policy.add_command(PolicyCommand::Select);
625        policy.add_command(PolicyCommand::Update);
626
627        assert_eq!(policy.commands.len(), 2);
628        assert!(policy.applies_to(PolicyCommand::Select));
629        assert!(policy.applies_to(PolicyCommand::Update));
630    }
631
632    #[test]
633    fn test_policy_with_roles() {
634        let policy = Policy::new(make_ident("auth"), make_ident("User"), make_span())
635            .with_roles(vec!["authenticated".into(), "admin".into()]);
636
637        assert_eq!(policy.roles.len(), 2);
638        let roles = policy.effective_roles();
639        assert!(roles.contains(&"authenticated"));
640        assert!(roles.contains(&"admin"));
641    }
642
643    #[test]
644    fn test_policy_add_role() {
645        let mut policy = Policy::new(make_ident("test"), make_ident("User"), make_span());
646
647        policy.add_role("user");
648        policy.add_role("moderator");
649
650        assert_eq!(policy.roles.len(), 2);
651    }
652
653    #[test]
654    fn test_policy_effective_roles_default() {
655        let policy = Policy::new(make_ident("public"), make_ident("User"), make_span());
656
657        let roles = policy.effective_roles();
658        assert_eq!(roles, vec!["PUBLIC"]);
659    }
660
661    #[test]
662    fn test_policy_with_using() {
663        let policy = Policy::new(make_ident("own"), make_ident("User"), make_span())
664            .with_using("user_id = current_user_id()");
665
666        assert_eq!(
667            policy.using_expr.as_deref(),
668            Some("user_id = current_user_id()")
669        );
670    }
671
672    #[test]
673    fn test_policy_with_check() {
674        let policy = Policy::new(make_ident("insert"), make_ident("User"), make_span())
675            .with_check("user_id = current_user_id()");
676
677        assert_eq!(
678            policy.check_expr.as_deref(),
679            Some("user_id = current_user_id()")
680        );
681    }
682
683    #[test]
684    fn test_policy_with_documentation() {
685        let policy =
686            Policy::new(make_ident("doc"), make_ident("User"), make_span()).with_documentation(
687                Documentation::new("Users can only see their own data", make_span()),
688            );
689
690        assert!(policy.documentation.is_some());
691        assert_eq!(
692            policy.documentation.unwrap().text,
693            "Users can only see their own data"
694        );
695    }
696
697    #[test]
698    fn test_policy_to_sql_simple() {
699        let policy = Policy::new(make_ident("read_own"), make_ident("User"), make_span())
700            .with_commands(vec![PolicyCommand::Select])
701            .with_using("id = current_user_id()");
702
703        let sql = policy.to_sql("users");
704        assert!(sql.contains("CREATE POLICY read_own ON users"));
705        assert!(sql.contains("FOR SELECT"));
706        assert!(sql.contains("TO PUBLIC"));
707        assert!(sql.contains("USING (id = current_user_id())"));
708    }
709
710    #[test]
711    fn test_policy_to_sql_with_roles() {
712        let policy = Policy::new(make_ident("auth_read"), make_ident("User"), make_span())
713            .with_commands(vec![PolicyCommand::Select])
714            .with_roles(vec!["authenticated".into()])
715            .with_using("true");
716
717        let sql = policy.to_sql("users");
718        assert!(sql.contains("TO authenticated"));
719    }
720
721    #[test]
722    fn test_policy_to_sql_restrictive() {
723        let policy = Policy::new(make_ident("restrict"), make_ident("User"), make_span())
724            .with_type(PolicyType::Restrictive)
725            .with_using("org_id = current_org_id()");
726
727        let sql = policy.to_sql("users");
728        assert!(sql.contains("AS RESTRICTIVE"));
729    }
730
731    #[test]
732    fn test_policy_to_sql_with_check() {
733        let policy = Policy::new(make_ident("insert_own"), make_ident("User"), make_span())
734            .with_commands(vec![PolicyCommand::Insert])
735            .with_check("id = current_user_id()");
736
737        let sql = policy.to_sql("users");
738        assert!(sql.contains("FOR INSERT"));
739        assert!(sql.contains("WITH CHECK (id = current_user_id())"));
740    }
741
742    #[test]
743    fn test_policy_to_sql_both_expressions() {
744        let policy = Policy::new(make_ident("update_own"), make_ident("User"), make_span())
745            .with_commands(vec![PolicyCommand::Update])
746            .with_using("id = current_user_id()")
747            .with_check("id = current_user_id()");
748
749        let sql = policy.to_sql("users");
750        assert!(sql.contains("USING (id = current_user_id())"));
751        assert!(sql.contains("WITH CHECK (id = current_user_id())"));
752    }
753
754    #[test]
755    fn test_policy_equality() {
756        let policy1 = Policy::new(make_ident("test"), make_ident("User"), make_span());
757        let policy2 = Policy::new(make_ident("test"), make_ident("User"), make_span());
758
759        assert_eq!(policy1, policy2);
760    }
761
762    #[test]
763    fn test_policy_clone() {
764        let policy = Policy::new(make_ident("original"), make_ident("User"), make_span())
765            .with_using("id = 1");
766
767        let cloned = policy.clone();
768        assert_eq!(cloned.name(), "original");
769        assert_eq!(cloned.using_expr, Some("id = 1".to_string()));
770    }
771
772    // ==================== PolicyType Tests ====================
773
774    #[test]
775    fn test_policy_type_from_str() {
776        assert_eq!(
777            PolicyType::from_str("PERMISSIVE"),
778            Some(PolicyType::Permissive)
779        );
780        assert_eq!(
781            PolicyType::from_str("permissive"),
782            Some(PolicyType::Permissive)
783        );
784        assert_eq!(
785            PolicyType::from_str("Permissive"),
786            Some(PolicyType::Permissive)
787        );
788        assert_eq!(
789            PolicyType::from_str("RESTRICTIVE"),
790            Some(PolicyType::Restrictive)
791        );
792        assert_eq!(
793            PolicyType::from_str("restrictive"),
794            Some(PolicyType::Restrictive)
795        );
796        assert_eq!(PolicyType::from_str("invalid"), None);
797    }
798
799    #[test]
800    fn test_policy_type_as_str() {
801        assert_eq!(PolicyType::Permissive.as_str(), "PERMISSIVE");
802        assert_eq!(PolicyType::Restrictive.as_str(), "RESTRICTIVE");
803    }
804
805    #[test]
806    fn test_policy_type_display() {
807        assert_eq!(format!("{}", PolicyType::Permissive), "PERMISSIVE");
808        assert_eq!(format!("{}", PolicyType::Restrictive), "RESTRICTIVE");
809    }
810
811    #[test]
812    fn test_policy_type_default() {
813        let policy_type: PolicyType = Default::default();
814        assert_eq!(policy_type, PolicyType::Permissive);
815    }
816
817    #[test]
818    fn test_policy_type_equality() {
819        assert_eq!(PolicyType::Permissive, PolicyType::Permissive);
820        assert_eq!(PolicyType::Restrictive, PolicyType::Restrictive);
821        assert_ne!(PolicyType::Permissive, PolicyType::Restrictive);
822    }
823
824    // ==================== PolicyCommand Tests ====================
825
826    #[test]
827    fn test_policy_command_from_str() {
828        assert_eq!(PolicyCommand::from_str("ALL"), Some(PolicyCommand::All));
829        assert_eq!(PolicyCommand::from_str("all"), Some(PolicyCommand::All));
830        assert_eq!(
831            PolicyCommand::from_str("SELECT"),
832            Some(PolicyCommand::Select)
833        );
834        assert_eq!(
835            PolicyCommand::from_str("select"),
836            Some(PolicyCommand::Select)
837        );
838        assert_eq!(
839            PolicyCommand::from_str("INSERT"),
840            Some(PolicyCommand::Insert)
841        );
842        assert_eq!(
843            PolicyCommand::from_str("UPDATE"),
844            Some(PolicyCommand::Update)
845        );
846        assert_eq!(
847            PolicyCommand::from_str("DELETE"),
848            Some(PolicyCommand::Delete)
849        );
850        assert_eq!(PolicyCommand::from_str("invalid"), None);
851    }
852
853    #[test]
854    fn test_policy_command_as_str() {
855        assert_eq!(PolicyCommand::All.as_str(), "ALL");
856        assert_eq!(PolicyCommand::Select.as_str(), "SELECT");
857        assert_eq!(PolicyCommand::Insert.as_str(), "INSERT");
858        assert_eq!(PolicyCommand::Update.as_str(), "UPDATE");
859        assert_eq!(PolicyCommand::Delete.as_str(), "DELETE");
860    }
861
862    #[test]
863    fn test_policy_command_display() {
864        assert_eq!(format!("{}", PolicyCommand::All), "ALL");
865        assert_eq!(format!("{}", PolicyCommand::Select), "SELECT");
866        assert_eq!(format!("{}", PolicyCommand::Insert), "INSERT");
867        assert_eq!(format!("{}", PolicyCommand::Update), "UPDATE");
868        assert_eq!(format!("{}", PolicyCommand::Delete), "DELETE");
869    }
870
#[test]
fn test_policy_command_requires_using() {
    // Commands expected to report that a USING expression is required.
    let needs_using = [
        PolicyCommand::All,
        PolicyCommand::Select,
        PolicyCommand::Update,
        PolicyCommand::Delete,
    ];
    for command in needs_using {
        assert!(command.requires_using());
    }

    // INSERT is the one command checked here that must not require USING.
    assert!(!PolicyCommand::Insert.requires_using());
}
879
#[test]
fn test_policy_command_requires_check() {
    // Commands expected to report that a WITH CHECK expression is required.
    let needs_check = [
        PolicyCommand::All,
        PolicyCommand::Insert,
        PolicyCommand::Update,
    ];
    for command in needs_check {
        assert!(command.requires_check());
    }

    // Commands expected to report that no WITH CHECK expression is required.
    let no_check = [PolicyCommand::Select, PolicyCommand::Delete];
    for command in no_check {
        assert!(!command.requires_check());
    }
}
888
#[test]
fn test_policy_command_equality() {
    // The same variant compares equal; two different variants do not.
    let select = PolicyCommand::Select;

    assert_eq!(select, PolicyCommand::Select);
    assert_ne!(select, PolicyCommand::Insert);
}
894
895    // ==================== Full Policy Scenario Tests ====================
896
#[test]
fn test_policy_rls_scenario_user_isolation() {
    // Scenario: Users can only see and modify their own records.
    // The same ownership expression is used for both USING and WITH CHECK.
    let own_row = "id = auth.uid()";

    let policy = Policy::new(make_ident("user_isolation"), make_ident("User"), make_span())
        .with_type(PolicyType::Permissive)
        .with_commands(vec![PolicyCommand::All])
        .with_roles(vec!["authenticated".into()])
        .with_using(own_row)
        .with_check(own_row);

    assert!(policy.is_permissive());

    // A policy declared for ALL must apply to each individual command.
    let individual_commands = [
        PolicyCommand::Select,
        PolicyCommand::Insert,
        PolicyCommand::Update,
        PolicyCommand::Delete,
    ];
    for command in individual_commands {
        assert!(policy.applies_to(command));
    }

    assert!(policy.to_sql("users").contains("auth.uid()"));
}
920
#[test]
fn test_policy_rls_scenario_org_based() {
    // Scenario: Users can only access records in their organization.
    let org_policy = Policy::new(make_ident("org_access"), make_ident("Document"), make_span())
        .with_type(PolicyType::Restrictive)
        .with_commands(vec![PolicyCommand::Select, PolicyCommand::Update])
        .with_using("org_id = current_setting('app.current_org')::uuid");

    assert!(org_policy.is_restrictive());

    // Only the two listed commands are covered; DELETE was never listed.
    for covered in [PolicyCommand::Select, PolicyCommand::Update] {
        assert!(org_policy.applies_to(covered));
    }
    assert!(!org_policy.applies_to(PolicyCommand::Delete));

    let generated = org_policy.to_sql("documents");
    for fragment in ["AS RESTRICTIVE", "current_setting"] {
        assert!(
            generated.contains(fragment),
            "generated SQL is missing {:?}",
            fragment
        );
    }
}
942
#[test]
fn test_policy_rls_scenario_public_read() {
    // Scenario: Anyone can read, only owner can modify.

    // Write side: restricted to the `authenticated` role.
    let owner_write = Policy::new(make_ident("owner_write"), make_ident("Post"), make_span())
        .with_commands(vec![PolicyCommand::Update, PolicyCommand::Delete])
        .with_roles(vec!["authenticated".into()])
        .with_using("author_id = current_user_id()");

    // Read side: no roles are set, so the effective roles are asserted to be PUBLIC.
    let public_read = Policy::new(make_ident("public_read"), make_ident("Post"), make_span())
        .with_commands(vec![PolicyCommand::Select])
        .with_using("published = true OR author_id = current_user_id()");

    assert_eq!(public_read.effective_roles(), vec!["PUBLIC"]);

    let writer_roles = owner_write.effective_roles();
    assert!(writer_roles.contains(&"authenticated"));
}
958
959    // ==================== MSSQL Block Operation Tests ====================
960
#[test]
fn test_mssql_block_operation_from_str() {
    // Input spelling paired with the expected parse result. The first three rows
    // cover the spaced, underscored and run-together spellings of AFTER INSERT.
    let cases = [
        ("AFTER INSERT", Some(MssqlBlockOperation::AfterInsert)),
        ("after_insert", Some(MssqlBlockOperation::AfterInsert)),
        ("AFTERINSERT", Some(MssqlBlockOperation::AfterInsert)),
        ("AFTER UPDATE", Some(MssqlBlockOperation::AfterUpdate)),
        ("BEFORE UPDATE", Some(MssqlBlockOperation::BeforeUpdate)),
        ("BEFORE DELETE", Some(MssqlBlockOperation::BeforeDelete)),
        ("invalid", None),
    ];

    for (input, expected) in cases {
        assert_eq!(MssqlBlockOperation::from_str(input), expected);
    }
}
989
#[test]
fn test_mssql_block_operation_as_str() {
    // Each block operation paired with the text `as_str` is expected to return.
    let expectations = [
        (MssqlBlockOperation::AfterInsert, "AFTER INSERT"),
        (MssqlBlockOperation::AfterUpdate, "AFTER UPDATE"),
        (MssqlBlockOperation::BeforeUpdate, "BEFORE UPDATE"),
        (MssqlBlockOperation::BeforeDelete, "BEFORE DELETE"),
    ];

    for (operation, text) in expectations {
        assert_eq!(operation.as_str(), text);
    }
}
997
#[test]
fn test_mssql_block_operation_display() {
    // `to_string()` goes through the `Display` impl; two variants are sampled.
    let after_insert = MssqlBlockOperation::AfterInsert.to_string();
    let before_delete = MssqlBlockOperation::BeforeDelete.to_string();

    assert_eq!(after_insert, "AFTER INSERT");
    assert_eq!(before_delete, "BEFORE DELETE");
}
1009
1010    // ==================== MSSQL Policy Tests ====================
1011
#[test]
fn test_policy_mssql_schema_default() {
    // No MSSQL schema is configured, so the reported schema must be "Security".
    let unconfigured = Policy::new(make_ident("test"), make_ident("User"), make_span());
    let schema = unconfigured.mssql_schema();

    assert_eq!(schema, "Security");
}
1017
#[test]
fn test_policy_with_mssql_schema() {
    // An explicitly configured schema must be reported back unchanged.
    let base = Policy::new(make_ident("test"), make_ident("User"), make_span());
    let configured = base.with_mssql_schema("RLS");

    assert_eq!(configured.mssql_schema(), "RLS");
}
1024
#[test]
fn test_policy_mssql_predicate_function_name() {
    // For a policy named `user_filter` the predicate function name is expected to be
    // the policy name wrapped as `fn_<name>_predicate`.
    let policy = Policy::new(make_ident("user_filter"), make_ident("User"), make_span());
    let function_name = policy.mssql_predicate_function_name();

    assert_eq!(function_name, "fn_user_filter_predicate");
}
1033
#[test]
fn test_policy_with_mssql_block_operations() {
    // Both operations handed to the builder must end up stored on the policy.
    let operations = vec![
        MssqlBlockOperation::AfterInsert,
        MssqlBlockOperation::AfterUpdate,
    ];

    let policy = Policy::new(make_ident("test"), make_ident("User"), make_span())
        .with_mssql_block_operations(operations);

    assert_eq!(policy.mssql_block_operations.len(), 2);
}
1044
#[test]
fn test_policy_add_mssql_block_operation() {
    // Operations added one at a time through the mutating API must accumulate.
    let mut policy = Policy::new(make_ident("test"), make_ident("User"), make_span());

    let additions = [
        MssqlBlockOperation::AfterInsert,
        MssqlBlockOperation::BeforeDelete,
    ];
    for operation in additions {
        policy.add_mssql_block_operation(operation);
    }

    assert_eq!(policy.mssql_block_operations.len(), 2);
}
1054
#[test]
fn test_policy_to_mssql_sql_simple() {
    // A SELECT-only policy with a USING expression, rendered for SQL Server.
    let filter_policy = Policy::new(make_ident("user_filter"), make_ident("User"), make_span())
        .with_commands(vec![PolicyCommand::Select])
        .with_using("UserId = @UserId");

    let statements = filter_policy.to_mssql_sql("dbo.Users", "UserId");

    // Schema creation.
    assert!(statements.schema_sql.contains("CREATE SCHEMA Security"));

    // Predicate function creation.
    let function_fragments = [
        "CREATE FUNCTION Security.fn_user_filter_predicate",
        "@UserId AS INT",
        "WITH SCHEMABINDING",
        "RETURNS TABLE",
    ];
    for fragment in function_fragments {
        assert!(
            statements.function_sql.contains(fragment),
            "function SQL is missing {:?}",
            fragment
        );
    }

    // Security policy creation.
    let policy_fragments = [
        "CREATE SECURITY POLICY Security.user_filter",
        "FILTER PREDICATE",
        "ON dbo.Users",
        "WITH (STATE = ON)",
    ];
    for fragment in policy_fragments {
        assert!(
            statements.policy_sql.contains(fragment),
            "policy SQL is missing {:?}",
            fragment
        );
    }
}
1086
#[test]
fn test_policy_to_mssql_sql_with_check() {
    // An INSERT-only policy with a CHECK expression must yield a block predicate
    // bound to AFTER INSERT.
    let insert_policy = Policy::new(make_ident("user_insert"), make_ident("User"), make_span())
        .with_commands(vec![PolicyCommand::Insert])
        .with_check("UserId = @UserId");

    let statements = insert_policy.to_mssql_sql("dbo.Users", "UserId");

    for fragment in ["BLOCK PREDICATE", "AFTER INSERT"] {
        assert!(
            statements.policy_sql.contains(fragment),
            "policy SQL is missing {:?}",
            fragment
        );
    }
}
1098
#[test]
fn test_policy_to_mssql_sql_with_both() {
    // An ALL policy carrying both USING and CHECK expressions.
    let predicate = "UserId = @UserId";
    let policy = Policy::new(make_ident("user_all"), make_ident("User"), make_span())
        .with_commands(vec![PolicyCommand::All])
        .with_using(predicate)
        .with_check(predicate);

    let statements = policy.to_mssql_sql("dbo.Users", "UserId");

    // Both predicate kinds plus all four block operations must be present.
    let required = [
        "FILTER PREDICATE",
        "BLOCK PREDICATE",
        "AFTER INSERT",
        "AFTER UPDATE",
        "BEFORE UPDATE",
        "BEFORE DELETE",
    ];
    for fragment in required {
        assert!(
            statements.policy_sql.contains(fragment),
            "policy SQL is missing {:?}",
            fragment
        );
    }
}
1115
#[test]
fn test_policy_to_mssql_sql_custom_schema() {
    // The custom schema name must show up in all three generated statements.
    let custom = Policy::new(make_ident("test"), make_ident("User"), make_span())
        .with_mssql_schema("RLS")
        .with_using("UserId = @UserId");

    let statements = custom.to_mssql_sql("dbo.Users", "UserId");

    let schema_ok = statements.schema_sql.contains("CREATE SCHEMA RLS");
    let function_ok = statements.function_sql.contains("RLS.fn_test_predicate");
    let policy_ok = statements.policy_sql.contains("RLS.test");

    assert!(schema_ok);
    assert!(function_ok);
    assert!(policy_ok);
}
1128
#[test]
fn test_policy_to_mssql_sql_translates_postgres_functions() {
    // `current_user_id()` in the USING expression must be replaced by a
    // SESSION_CONTEXT lookup in the generated predicate function.
    let policy = Policy::new(make_ident("test"), make_ident("User"), make_span())
        .with_using("id = current_user_id()");

    let statements = policy.to_mssql_sql("dbo.Users", "UserId");
    let function_sql = &statements.function_sql;

    assert!(function_sql.contains("SESSION_CONTEXT(N'UserId')"));
    assert!(!function_sql.contains("current_user_id"));
}
1139
#[test]
fn test_policy_to_mssql_sql_translates_auth_uid() {
    // `auth.uid()` in the USING expression must be replaced by a SESSION_CONTEXT
    // lookup in the generated predicate function.
    let policy = Policy::new(make_ident("test"), make_ident("User"), make_span())
        .with_using("id = auth.uid()");

    let statements = policy.to_mssql_sql("dbo.Users", "UserId");
    let function_sql = &statements.function_sql;

    assert!(function_sql.contains("SESSION_CONTEXT(N'UserId')"));
    assert!(!function_sql.contains("auth.uid"));
}
1150
#[test]
fn test_mssql_policy_statements_all_statements() {
    // `all_statements` is expected to yield exactly three entries for this policy.
    let policy = Policy::new(make_ident("test"), make_ident("User"), make_span())
        .with_using("UserId = @UserId");

    let generated = policy.to_mssql_sql("dbo.Users", "UserId");
    let statement_count = generated.all_statements().len();

    assert_eq!(statement_count, 3);
}
1161
#[test]
fn test_mssql_policy_statements_to_sql() {
    // The combined script must contain a GO separator and each CREATE statement.
    let policy = Policy::new(make_ident("test"), make_ident("User"), make_span())
        .with_using("UserId = @UserId");

    let script = policy.to_mssql_sql("dbo.Users", "UserId").to_sql();

    let required = [
        "GO",
        "CREATE SCHEMA",
        "CREATE FUNCTION",
        "CREATE SECURITY POLICY",
    ];
    for fragment in required {
        assert!(
            script.contains(fragment),
            "combined SQL is missing {:?}",
            fragment
        );
    }
}
1175
#[test]
fn test_policy_default_mssql_block_operations() {
    // Each row: the single command the policy targets, the block operations that
    // must appear in the generated policy SQL, and the ones that must not.
    let scenarios: [(PolicyCommand, &[&str], &[&str]); 3] = [
        // INSERT only
        (PolicyCommand::Insert, &["AFTER INSERT"], &["BEFORE DELETE"]),
        // UPDATE only
        (
            PolicyCommand::Update,
            &["AFTER UPDATE", "BEFORE UPDATE"],
            &["AFTER INSERT"],
        ),
        // DELETE only
        (PolicyCommand::Delete, &["BEFORE DELETE"], &["AFTER INSERT"]),
    ];

    for (command, expected, forbidden) in scenarios {
        let policy = Policy::new(make_ident("test"), make_ident("User"), make_span())
            .with_commands(vec![command])
            .with_check("true");
        let mssql = policy.to_mssql_sql("dbo.Users", "UserId");

        for fragment in expected {
            assert!(
                mssql.policy_sql.contains(*fragment),
                "policy SQL is missing {:?}",
                fragment
            );
        }
        for fragment in forbidden {
            assert!(
                !mssql.policy_sql.contains(*fragment),
                "policy SQL unexpectedly contains {:?}",
                fragment
            );
        }
    }
}
1206
#[test]
fn test_policy_mssql_custom_block_operations() {
    // Explicit block operations are supplied on an ALL policy.
    let explicit_ops = vec![MssqlBlockOperation::AfterInsert];
    let policy = Policy::new(make_ident("test"), make_ident("User"), make_span())
        .with_commands(vec![PolicyCommand::All])
        .with_check("true")
        .with_mssql_block_operations(explicit_ops);

    let statements = policy.to_mssql_sql("dbo.Users", "UserId");

    // Should only have the custom operation, not all defaults.
    assert!(statements.policy_sql.contains("AFTER INSERT"));
    for default_only in ["BEFORE DELETE", "AFTER UPDATE"] {
        assert!(
            !statements.policy_sql.contains(default_only),
            "policy SQL unexpectedly contains {:?}",
            default_only
        );
    }
}
1221
1222    // ==================== MSSQL Real-World Scenario Tests ====================
1223
#[test]
fn test_mssql_rls_scenario_user_isolation() {
    // Scenario: an ALL policy whose USING and CHECK expressions are already written
    // in SQL Server syntax.
    let session_user = "UserId = CAST(SESSION_CONTEXT(N'UserId') AS INT)";

    let policy = Policy::new(make_ident("user_isolation"), make_ident("User"), make_span())
        .with_mssql_schema("Security")
        .with_commands(vec![PolicyCommand::All])
        .with_using(session_user)
        .with_check(session_user);

    let statements = policy.to_mssql_sql("dbo.Users", "UserId");

    // Verify complete MSSQL RLS setup: schema, predicate function, then policy.
    assert!(statements.schema_sql.contains("CREATE SCHEMA Security"));
    assert!(statements.function_sql.contains("fn_user_isolation_predicate"));
    for fragment in ["user_isolation", "WITH (STATE = ON)"] {
        assert!(
            statements.policy_sql.contains(fragment),
            "policy SQL is missing {:?}",
            fragment
        );
    }
}
1244
#[test]
fn test_mssql_rls_scenario_multi_tenant() {
    // Scenario: tenant isolation on an Order table, keyed by the `TenantId`
    // session value and placed in a dedicated schema.
    let tenant_match = "TenantId = CAST(SESSION_CONTEXT(N'TenantId') AS INT)";

    let policy = Policy::new(make_ident("tenant_isolation"), make_ident("Order"), make_span())
        .with_mssql_schema("MultiTenant")
        .with_using(tenant_match)
        .with_check(tenant_match);

    let statements = policy.to_mssql_sql("dbo.Orders", "TenantId");

    let schema_ok = statements.schema_sql.contains("MultiTenant");
    let function_ok = statements.function_sql.contains("@TenantId AS INT");
    let policy_ok = statements.policy_sql.contains("dbo.Orders");

    assert!(schema_ok);
    assert!(function_ok);
    assert!(policy_ok);
}
1262}