1use serde::{Deserialize, Serialize};
26use smol_str::SmolStr;
27
28use super::{Documentation, Ident, Span};
29
/// A row-level security policy declared on a model.
///
/// Rendered as PostgreSQL DDL by [`Policy::to_postgres_sql`] and as a set of
/// SQL Server statements by [`Policy::to_mssql_sql`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Policy {
    /// Policy name (used as the PostgreSQL policy name and the SQL Server
    /// security-policy name).
    pub name: Ident,
    /// Identifier of the table (model) the policy is declared on.
    pub table: Ident,
    /// Permissive or restrictive.
    pub policy_type: PolicyType,
    /// Commands the policy applies to; [`PolicyCommand::All`] covers every command.
    pub commands: Vec<PolicyCommand>,
    /// Roles the policy targets; an empty list is rendered as `PUBLIC`.
    pub roles: Vec<SmolStr>,
    /// `USING` expression as raw SQL text.
    pub using_expr: Option<String>,
    /// `WITH CHECK` expression as raw SQL text.
    pub check_expr: Option<String>,
    /// Schema for generated SQL Server objects; `None` means `Security`.
    pub mssql_schema: Option<SmolStr>,
    /// Explicit SQL Server block operations; empty means "derive from `commands`".
    pub mssql_block_operations: Vec<MssqlBlockOperation>,
    /// Optional documentation attached to the policy.
    pub documentation: Option<Documentation>,
    /// Source span of the policy declaration.
    pub span: Span,
}
85
86impl Policy {
87 pub fn new(name: Ident, table: Ident, span: Span) -> Self {
89 Self {
90 name,
91 table,
92 policy_type: PolicyType::Permissive,
93 commands: vec![PolicyCommand::All],
94 roles: vec![],
95 using_expr: None,
96 check_expr: None,
97 mssql_schema: None,
98 mssql_block_operations: vec![],
99 documentation: None,
100 span,
101 }
102 }
103
104 pub fn name(&self) -> &str {
106 self.name.as_str()
107 }
108
109 pub fn table(&self) -> &str {
111 self.table.as_str()
112 }
113
114 pub fn with_type(mut self, policy_type: PolicyType) -> Self {
116 self.policy_type = policy_type;
117 self
118 }
119
120 pub fn with_commands(mut self, commands: Vec<PolicyCommand>) -> Self {
122 self.commands = commands;
123 self
124 }
125
126 pub fn add_command(&mut self, command: PolicyCommand) {
128 self.commands.push(command);
129 }
130
131 pub fn with_roles(mut self, roles: Vec<SmolStr>) -> Self {
133 self.roles = roles;
134 self
135 }
136
137 pub fn add_role(&mut self, role: impl Into<SmolStr>) {
139 self.roles.push(role.into());
140 }
141
142 pub fn with_using(mut self, expr: impl Into<String>) -> Self {
144 self.using_expr = Some(expr.into());
145 self
146 }
147
148 pub fn with_check(mut self, expr: impl Into<String>) -> Self {
150 self.check_expr = Some(expr.into());
151 self
152 }
153
154 pub fn with_documentation(mut self, doc: Documentation) -> Self {
156 self.documentation = Some(doc);
157 self
158 }
159
160 pub fn with_mssql_schema(mut self, schema: impl Into<SmolStr>) -> Self {
162 self.mssql_schema = Some(schema.into());
163 self
164 }
165
166 pub fn with_mssql_block_operations(mut self, operations: Vec<MssqlBlockOperation>) -> Self {
168 self.mssql_block_operations = operations;
169 self
170 }
171
172 pub fn add_mssql_block_operation(&mut self, operation: MssqlBlockOperation) {
174 self.mssql_block_operations.push(operation);
175 }
176
177 pub fn applies_to(&self, command: PolicyCommand) -> bool {
179 self.commands.contains(&PolicyCommand::All) || self.commands.contains(&command)
180 }
181
182 pub fn is_restrictive(&self) -> bool {
184 self.policy_type == PolicyType::Restrictive
185 }
186
187 pub fn is_permissive(&self) -> bool {
189 self.policy_type == PolicyType::Permissive
190 }
191
192 pub fn effective_roles(&self) -> Vec<&str> {
194 if self.roles.is_empty() {
195 vec!["PUBLIC"]
196 } else {
197 self.roles.iter().map(|r| r.as_str()).collect()
198 }
199 }
200
201 pub fn mssql_schema(&self) -> &str {
203 self.mssql_schema
204 .as_ref()
205 .map(|s| s.as_str())
206 .unwrap_or("Security")
207 }
208
209 pub fn mssql_predicate_function_name(&self) -> String {
211 format!("fn_{}_predicate", self.name())
212 }
213
214 pub fn to_sql(&self, table_name: &str) -> String {
216 self.to_postgres_sql(table_name)
217 }
218
219 pub fn to_postgres_sql(&self, table_name: &str) -> String {
221 let mut sql = format!("CREATE POLICY {} ON {}", self.name(), table_name);
222
223 match self.policy_type {
225 PolicyType::Permissive => {} PolicyType::Restrictive => sql.push_str(" AS RESTRICTIVE"),
227 }
228
229 if !self.commands.is_empty() && !self.commands.contains(&PolicyCommand::All) {
231 let cmds: Vec<&str> = self.commands.iter().map(|c| c.as_str()).collect();
232 sql.push_str(&format!(" FOR {}", cmds[0]));
234 }
235
236 let roles = self.effective_roles();
238 sql.push_str(&format!(" TO {}", roles.join(", ")));
239
240 if let Some(ref using) = self.using_expr {
242 sql.push_str(&format!(" USING ({})", using));
243 }
244
245 if let Some(ref check) = self.check_expr {
247 sql.push_str(&format!(" WITH CHECK ({})", check));
248 }
249
250 sql
251 }
252
253 pub fn to_mssql_sql(&self, table_name: &str, predicate_column: &str) -> MssqlPolicyStatements {
270 let schema = self.mssql_schema();
271 let func_name = self.mssql_predicate_function_name();
272
273 let filter_expr = self
275 .using_expr
276 .as_deref()
277 .unwrap_or("1 = 1")
278 .replace(
279 "current_user_id()",
280 "CAST(SESSION_CONTEXT(N'UserId') AS INT)",
281 )
282 .replace("auth.uid()", "CAST(SESSION_CONTEXT(N'UserId') AS INT)")
283 .replace(
284 "current_setting('app.current_org')",
285 "SESSION_CONTEXT(N'OrgId')",
286 );
287
288 let function_sql = format!(
289 r#"CREATE FUNCTION {schema}.{func_name}(@{predicate_column} AS INT)
290 RETURNS TABLE
291WITH SCHEMABINDING
292AS
293 RETURN SELECT 1 AS fn_securitypredicate_result
294 WHERE {filter_expr}"#,
295 schema = schema,
296 func_name = func_name,
297 predicate_column = predicate_column,
298 filter_expr = filter_expr
299 );
300
301 let mut policy_sql = format!(
303 "CREATE SECURITY POLICY {schema}.{policy_name}\n",
304 schema = schema,
305 policy_name = self.name()
306 );
307
308 if self.using_expr.is_some() {
310 policy_sql.push_str(&format!(
311 "ADD FILTER PREDICATE {schema}.{func_name}({predicate_column}) ON {table_name}",
312 schema = schema,
313 func_name = func_name,
314 predicate_column = predicate_column,
315 table_name = table_name
316 ));
317 }
318
319 if self.check_expr.is_some() {
321 let block_ops = if self.mssql_block_operations.is_empty() {
322 self.default_mssql_block_operations()
324 } else {
325 self.mssql_block_operations.clone()
326 };
327
328 for (i, op) in block_ops.iter().enumerate() {
329 if i > 0 || self.using_expr.is_some() {
330 policy_sql.push_str(",\n");
331 }
332 policy_sql.push_str(&format!(
333 "ADD BLOCK PREDICATE {schema}.{func_name}({predicate_column}) ON {table_name} {op}",
334 schema = schema,
335 func_name = func_name,
336 predicate_column = predicate_column,
337 table_name = table_name,
338 op = op.as_str()
339 ));
340 }
341 }
342
343 policy_sql.push_str("\nWITH (STATE = ON)");
344
345 MssqlPolicyStatements {
346 schema_sql: format!("CREATE SCHEMA {schema}"),
347 function_sql,
348 policy_sql,
349 }
350 }
351
352 fn default_mssql_block_operations(&self) -> Vec<MssqlBlockOperation> {
354 let mut ops = vec![];
355
356 if self.applies_to(PolicyCommand::Insert) {
357 ops.push(MssqlBlockOperation::AfterInsert);
358 }
359 if self.applies_to(PolicyCommand::Update) {
360 ops.push(MssqlBlockOperation::AfterUpdate);
361 ops.push(MssqlBlockOperation::BeforeUpdate);
362 }
363 if self.applies_to(PolicyCommand::Delete) {
364 ops.push(MssqlBlockOperation::BeforeDelete);
365 }
366
367 ops
368 }
369}
370
/// Whether a policy is permissive or restrictive, mirroring PostgreSQL's
/// `AS PERMISSIVE` / `AS RESTRICTIVE`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
pub enum PolicyType {
    /// The default; rendered without an `AS` clause in PostgreSQL output.
    #[default]
    Permissive,
    /// Rendered as `AS RESTRICTIVE` in PostgreSQL output.
    Restrictive,
}
382
383impl PolicyType {
384 #[allow(clippy::should_implement_trait)]
386 pub fn from_str(s: &str) -> Option<Self> {
387 match s.to_uppercase().as_str() {
388 "PERMISSIVE" => Some(Self::Permissive),
389 "RESTRICTIVE" => Some(Self::Restrictive),
390 _ => None,
391 }
392 }
393
394 pub fn as_str(&self) -> &'static str {
396 match self {
397 Self::Permissive => "PERMISSIVE",
398 Self::Restrictive => "RESTRICTIVE",
399 }
400 }
401}
402
403impl std::fmt::Display for PolicyType {
404 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
405 write!(f, "{}", self.as_str())
406 }
407}
408
/// SQL command a policy applies to, mirroring PostgreSQL's `FOR` clause.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum PolicyCommand {
    /// Every command; makes `Policy::applies_to` return `true` for any command.
    All,
    /// `SELECT`.
    Select,
    /// `INSERT`.
    Insert,
    /// `UPDATE`.
    Update,
    /// `DELETE`.
    Delete,
}
423
424impl PolicyCommand {
425 #[allow(clippy::should_implement_trait)]
427 pub fn from_str(s: &str) -> Option<Self> {
428 match s.to_uppercase().as_str() {
429 "ALL" => Some(Self::All),
430 "SELECT" => Some(Self::Select),
431 "INSERT" => Some(Self::Insert),
432 "UPDATE" => Some(Self::Update),
433 "DELETE" => Some(Self::Delete),
434 _ => None,
435 }
436 }
437
438 pub fn as_str(&self) -> &'static str {
440 match self {
441 Self::All => "ALL",
442 Self::Select => "SELECT",
443 Self::Insert => "INSERT",
444 Self::Update => "UPDATE",
445 Self::Delete => "DELETE",
446 }
447 }
448
449 pub fn requires_using(&self) -> bool {
451 matches!(self, Self::All | Self::Select | Self::Update | Self::Delete)
452 }
453
454 pub fn requires_check(&self) -> bool {
456 matches!(self, Self::All | Self::Insert | Self::Update)
457 }
458}
459
460impl std::fmt::Display for PolicyCommand {
461 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
462 write!(f, "{}", self.as_str())
463 }
464}
465
/// Operation guarded by a SQL Server `BLOCK PREDICATE`.
///
/// Each variant renders as the keyword pair that follows
/// `ADD BLOCK PREDICATE ... ON <table>`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum MssqlBlockOperation {
    /// `AFTER INSERT`.
    AfterInsert,
    /// `AFTER UPDATE`.
    AfterUpdate,
    /// `BEFORE UPDATE`.
    BeforeUpdate,
    /// `BEFORE DELETE`.
    BeforeDelete,
}
485
486impl MssqlBlockOperation {
487 #[allow(clippy::should_implement_trait)]
489 pub fn from_str(s: &str) -> Option<Self> {
490 match s.to_uppercase().replace([' ', '_'], "").as_str() {
491 "AFTERINSERT" => Some(Self::AfterInsert),
492 "AFTERUPDATE" => Some(Self::AfterUpdate),
493 "BEFOREUPDATE" => Some(Self::BeforeUpdate),
494 "BEFOREDELETE" => Some(Self::BeforeDelete),
495 _ => None,
496 }
497 }
498
499 pub fn as_str(&self) -> &'static str {
501 match self {
502 Self::AfterInsert => "AFTER INSERT",
503 Self::AfterUpdate => "AFTER UPDATE",
504 Self::BeforeUpdate => "BEFORE UPDATE",
505 Self::BeforeDelete => "BEFORE DELETE",
506 }
507 }
508}
509
510impl std::fmt::Display for MssqlBlockOperation {
511 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
512 write!(f, "{}", self.as_str())
513 }
514}
515
/// The SQL Server statements produced by `Policy::to_mssql_sql`, in execution order.
#[derive(Debug, Clone, PartialEq)]
pub struct MssqlPolicyStatements {
    /// `CREATE SCHEMA` statement for the schema holding the generated objects.
    pub schema_sql: String,
    /// `CREATE FUNCTION` statement for the inline predicate function.
    pub function_sql: String,
    /// `CREATE SECURITY POLICY` statement wiring the predicate to the table.
    pub policy_sql: String,
}
526
527impl MssqlPolicyStatements {
528 pub fn all_statements(&self) -> Vec<&str> {
530 vec![&self.schema_sql, &self.function_sql, &self.policy_sql]
531 }
532
533 pub fn to_sql(&self) -> String {
535 format!(
536 "{schema_sql};\nGO\n\n{function_sql};\nGO\n\n{policy_sql};",
537 schema_sql = self.schema_sql,
538 function_sql = self.function_sql,
539 policy_sql = self.policy_sql
540 )
541 }
542}
543
#[cfg(test)]
mod tests {
    use super::*;

    // -- fixtures & assertion helpers -----------------------------------------

    /// Span shared by every fixture; no test inspects source positions.
    fn make_span() -> Span {
        Span::new(0, 10)
    }

    /// Identifier carrying the shared fixture span.
    fn make_ident(name: &str) -> Ident {
        Ident::new(name, make_span())
    }

    /// A freshly constructed policy `name` on model `table`, with all defaults.
    fn policy(name: &str, table: &str) -> Policy {
        Policy::new(make_ident(name), make_ident(table), make_span())
    }

    /// Asserts that every fragment occurs somewhere in `sql`.
    fn assert_contains_all(sql: &str, fragments: &[&str]) {
        for &fragment in fragments {
            assert!(sql.contains(fragment), "expected {fragment:?} in:\n{sql}");
        }
    }

    /// Asserts that none of the fragments occur in `sql`.
    fn assert_contains_none(sql: &str, fragments: &[&str]) {
        for &fragment in fragments {
            assert!(!sql.contains(fragment), "unexpected {fragment:?} in:\n{sql}");
        }
    }

    /// Asserts that `applies_to` returns `expected` for every listed command.
    fn assert_applies(policy: &Policy, commands: &[PolicyCommand], expected: bool) {
        for &command in commands {
            assert_eq!(policy.applies_to(command), expected, "applies_to({command})");
        }
    }

    /// Every command paired with its SQL keyword.
    fn command_names() -> [(PolicyCommand, &'static str); 5] {
        [
            (PolicyCommand::All, "ALL"),
            (PolicyCommand::Select, "SELECT"),
            (PolicyCommand::Insert, "INSERT"),
            (PolicyCommand::Update, "UPDATE"),
            (PolicyCommand::Delete, "DELETE"),
        ]
    }

    // -- Policy: construction and builders ------------------------------------

    #[test]
    fn test_policy_new() {
        let p = policy("read_own", "User");

        assert_eq!(p.name(), "read_own");
        assert_eq!(p.table(), "User");
        assert_eq!(p.policy_type, PolicyType::Permissive);
        assert_eq!(p.commands, vec![PolicyCommand::All]);
        assert!(p.roles.is_empty());
        assert!(p.using_expr.is_none());
        assert!(p.check_expr.is_none());
        assert!(p.documentation.is_none());
    }

    #[test]
    fn test_policy_with_type() {
        let p = policy("strict", "User").with_type(PolicyType::Restrictive);

        assert!(p.is_restrictive());
        assert!(!p.is_permissive());
    }

    #[test]
    fn test_policy_with_commands() {
        let p = policy("read", "User").with_commands(vec![PolicyCommand::Select]);

        assert_applies(&p, &[PolicyCommand::Select], true);
        assert_applies(
            &p,
            &[PolicyCommand::Insert, PolicyCommand::Update, PolicyCommand::Delete],
            false,
        );
    }

    #[test]
    fn test_policy_with_multiple_commands() {
        let p = policy("read_update", "User")
            .with_commands(vec![PolicyCommand::Select, PolicyCommand::Update]);

        assert_applies(&p, &[PolicyCommand::Select, PolicyCommand::Update], true);
        assert_applies(&p, &[PolicyCommand::Insert, PolicyCommand::Delete], false);
    }

    #[test]
    fn test_policy_all_command_applies_to_all() {
        let p = policy("all", "User").with_commands(vec![PolicyCommand::All]);

        assert_applies(
            &p,
            &[
                PolicyCommand::Select,
                PolicyCommand::Insert,
                PolicyCommand::Update,
                PolicyCommand::Delete,
                PolicyCommand::All,
            ],
            true,
        );
    }

    #[test]
    fn test_policy_add_command() {
        let mut p = policy("test", "User").with_commands(vec![]);
        p.add_command(PolicyCommand::Select);
        p.add_command(PolicyCommand::Update);

        assert_eq!(p.commands.len(), 2);
        assert_applies(&p, &[PolicyCommand::Select, PolicyCommand::Update], true);
    }

    #[test]
    fn test_policy_with_roles() {
        let p = policy("auth", "User").with_roles(vec!["authenticated".into(), "admin".into()]);

        assert_eq!(p.roles.len(), 2);
        let roles = p.effective_roles();
        assert!(roles.contains(&"authenticated"));
        assert!(roles.contains(&"admin"));
    }

    #[test]
    fn test_policy_add_role() {
        let mut p = policy("test", "User");
        for role in ["user", "moderator"] {
            p.add_role(role);
        }

        assert_eq!(p.roles.len(), 2);
    }

    #[test]
    fn test_policy_effective_roles_default() {
        let p = policy("public", "User");
        assert_eq!(p.effective_roles(), vec!["PUBLIC"]);
    }

    #[test]
    fn test_policy_with_using() {
        let p = policy("own", "User").with_using("user_id = current_user_id()");
        assert_eq!(p.using_expr.as_deref(), Some("user_id = current_user_id()"));
    }

    #[test]
    fn test_policy_with_check() {
        let p = policy("insert", "User").with_check("user_id = current_user_id()");
        assert_eq!(p.check_expr.as_deref(), Some("user_id = current_user_id()"));
    }

    #[test]
    fn test_policy_with_documentation() {
        let p = policy("doc", "User").with_documentation(Documentation::new(
            "Users can only see their own data",
            make_span(),
        ));

        assert!(p.documentation.is_some());
        assert_eq!(p.documentation.unwrap().text, "Users can only see their own data");
    }

    // -- Policy: PostgreSQL rendering -----------------------------------------

    #[test]
    fn test_policy_to_sql_simple() {
        let sql = policy("read_own", "User")
            .with_commands(vec![PolicyCommand::Select])
            .with_using("id = current_user_id()")
            .to_sql("users");

        assert_contains_all(
            &sql,
            &[
                "CREATE POLICY read_own ON users",
                "FOR SELECT",
                "TO PUBLIC",
                "USING (id = current_user_id())",
            ],
        );
    }

    #[test]
    fn test_policy_to_sql_with_roles() {
        let sql = policy("auth_read", "User")
            .with_commands(vec![PolicyCommand::Select])
            .with_roles(vec!["authenticated".into()])
            .with_using("true")
            .to_sql("users");

        assert_contains_all(&sql, &["TO authenticated"]);
    }

    #[test]
    fn test_policy_to_sql_restrictive() {
        let sql = policy("restrict", "User")
            .with_type(PolicyType::Restrictive)
            .with_using("org_id = current_org_id()")
            .to_sql("users");

        assert_contains_all(&sql, &["AS RESTRICTIVE"]);
    }

    #[test]
    fn test_policy_to_sql_with_check() {
        let sql = policy("insert_own", "User")
            .with_commands(vec![PolicyCommand::Insert])
            .with_check("id = current_user_id()")
            .to_sql("users");

        assert_contains_all(&sql, &["FOR INSERT", "WITH CHECK (id = current_user_id())"]);
    }

    #[test]
    fn test_policy_to_sql_both_expressions() {
        let sql = policy("update_own", "User")
            .with_commands(vec![PolicyCommand::Update])
            .with_using("id = current_user_id()")
            .with_check("id = current_user_id()")
            .to_sql("users");

        assert_contains_all(
            &sql,
            &["USING (id = current_user_id())", "WITH CHECK (id = current_user_id())"],
        );
    }

    #[test]
    fn test_policy_equality() {
        assert_eq!(policy("test", "User"), policy("test", "User"));
    }

    #[test]
    fn test_policy_clone() {
        let original = policy("original", "User").with_using("id = 1");
        let cloned = original.clone();

        assert_eq!(cloned.name(), "original");
        assert_eq!(cloned.using_expr, Some("id = 1".to_string()));
    }

    // -- PolicyType -----------------------------------------------------------

    #[test]
    fn test_policy_type_from_str() {
        let cases = [
            ("PERMISSIVE", Some(PolicyType::Permissive)),
            ("permissive", Some(PolicyType::Permissive)),
            ("Permissive", Some(PolicyType::Permissive)),
            ("RESTRICTIVE", Some(PolicyType::Restrictive)),
            ("restrictive", Some(PolicyType::Restrictive)),
            ("invalid", None),
        ];

        for (input, expected) in cases {
            assert_eq!(PolicyType::from_str(input), expected, "input {input:?}");
        }
    }

    #[test]
    fn test_policy_type_as_str() {
        for (ty, text) in [
            (PolicyType::Permissive, "PERMISSIVE"),
            (PolicyType::Restrictive, "RESTRICTIVE"),
        ] {
            assert_eq!(ty.as_str(), text);
        }
    }

    #[test]
    fn test_policy_type_display() {
        for (ty, text) in [
            (PolicyType::Permissive, "PERMISSIVE"),
            (PolicyType::Restrictive, "RESTRICTIVE"),
        ] {
            assert_eq!(format!("{ty}"), text);
        }
    }

    #[test]
    fn test_policy_type_default() {
        assert_eq!(PolicyType::default(), PolicyType::Permissive);
    }

    #[test]
    fn test_policy_type_equality() {
        assert_eq!(PolicyType::Permissive, PolicyType::Permissive);
        assert_eq!(PolicyType::Restrictive, PolicyType::Restrictive);
        assert_ne!(PolicyType::Permissive, PolicyType::Restrictive);
    }

    // -- PolicyCommand --------------------------------------------------------

    #[test]
    fn test_policy_command_from_str() {
        let cases = [
            ("ALL", Some(PolicyCommand::All)),
            ("all", Some(PolicyCommand::All)),
            ("SELECT", Some(PolicyCommand::Select)),
            ("select", Some(PolicyCommand::Select)),
            ("INSERT", Some(PolicyCommand::Insert)),
            ("UPDATE", Some(PolicyCommand::Update)),
            ("DELETE", Some(PolicyCommand::Delete)),
            ("invalid", None),
        ];

        for (input, expected) in cases {
            assert_eq!(PolicyCommand::from_str(input), expected, "input {input:?}");
        }
    }

    #[test]
    fn test_policy_command_as_str() {
        for (command, text) in command_names() {
            assert_eq!(command.as_str(), text);
        }
    }

    #[test]
    fn test_policy_command_display() {
        for (command, text) in command_names() {
            assert_eq!(format!("{command}"), text);
        }
    }

    #[test]
    fn test_policy_command_requires_using() {
        for (command, expected) in [
            (PolicyCommand::All, true),
            (PolicyCommand::Select, true),
            (PolicyCommand::Update, true),
            (PolicyCommand::Delete, true),
            (PolicyCommand::Insert, false),
        ] {
            assert_eq!(command.requires_using(), expected, "{command}");
        }
    }

    #[test]
    fn test_policy_command_requires_check() {
        for (command, expected) in [
            (PolicyCommand::All, true),
            (PolicyCommand::Insert, true),
            (PolicyCommand::Update, true),
            (PolicyCommand::Select, false),
            (PolicyCommand::Delete, false),
        ] {
            assert_eq!(command.requires_check(), expected, "{command}");
        }
    }

    #[test]
    fn test_policy_command_equality() {
        assert_eq!(PolicyCommand::Select, PolicyCommand::Select);
        assert_ne!(PolicyCommand::Select, PolicyCommand::Insert);
    }

    // -- PostgreSQL RLS scenarios ---------------------------------------------

    #[test]
    fn test_policy_rls_scenario_user_isolation() {
        let p = policy("user_isolation", "User")
            .with_type(PolicyType::Permissive)
            .with_commands(vec![PolicyCommand::All])
            .with_roles(vec!["authenticated".into()])
            .with_using("id = auth.uid()")
            .with_check("id = auth.uid()");

        assert!(p.is_permissive());
        assert_applies(
            &p,
            &[
                PolicyCommand::Select,
                PolicyCommand::Insert,
                PolicyCommand::Update,
                PolicyCommand::Delete,
            ],
            true,
        );
        assert_contains_all(&p.to_sql("users"), &["auth.uid()"]);
    }

    #[test]
    fn test_policy_rls_scenario_org_based() {
        let p = policy("org_access", "Document")
            .with_type(PolicyType::Restrictive)
            .with_commands(vec![PolicyCommand::Select, PolicyCommand::Update])
            .with_using("org_id = current_setting('app.current_org')::uuid");

        assert!(p.is_restrictive());
        assert_applies(&p, &[PolicyCommand::Select, PolicyCommand::Update], true);
        assert_applies(&p, &[PolicyCommand::Delete], false);
        assert_contains_all(&p.to_sql("documents"), &["AS RESTRICTIVE", "current_setting"]);
    }

    #[test]
    fn test_policy_rls_scenario_public_read() {
        let read_policy = policy("public_read", "Post")
            .with_commands(vec![PolicyCommand::Select])
            .with_using("published = true OR author_id = current_user_id()");

        let write_policy = policy("owner_write", "Post")
            .with_commands(vec![PolicyCommand::Update, PolicyCommand::Delete])
            .with_roles(vec!["authenticated".into()])
            .with_using("author_id = current_user_id()");

        assert_eq!(read_policy.effective_roles(), vec!["PUBLIC"]);
        assert!(write_policy.effective_roles().contains(&"authenticated"));
    }

    // -- MssqlBlockOperation --------------------------------------------------

    #[test]
    fn test_mssql_block_operation_from_str() {
        let cases = [
            ("AFTER INSERT", Some(MssqlBlockOperation::AfterInsert)),
            ("after_insert", Some(MssqlBlockOperation::AfterInsert)),
            ("AFTERINSERT", Some(MssqlBlockOperation::AfterInsert)),
            ("AFTER UPDATE", Some(MssqlBlockOperation::AfterUpdate)),
            ("BEFORE UPDATE", Some(MssqlBlockOperation::BeforeUpdate)),
            ("BEFORE DELETE", Some(MssqlBlockOperation::BeforeDelete)),
            ("invalid", None),
        ];

        for (input, expected) in cases {
            assert_eq!(MssqlBlockOperation::from_str(input), expected, "input {input:?}");
        }
    }

    #[test]
    fn test_mssql_block_operation_as_str() {
        for (op, text) in [
            (MssqlBlockOperation::AfterInsert, "AFTER INSERT"),
            (MssqlBlockOperation::AfterUpdate, "AFTER UPDATE"),
            (MssqlBlockOperation::BeforeUpdate, "BEFORE UPDATE"),
            (MssqlBlockOperation::BeforeDelete, "BEFORE DELETE"),
        ] {
            assert_eq!(op.as_str(), text);
        }
    }

    #[test]
    fn test_mssql_block_operation_display() {
        for (op, text) in [
            (MssqlBlockOperation::AfterInsert, "AFTER INSERT"),
            (MssqlBlockOperation::BeforeDelete, "BEFORE DELETE"),
        ] {
            assert_eq!(format!("{op}"), text);
        }
    }

    // -- Policy: SQL Server configuration -------------------------------------

    #[test]
    fn test_policy_mssql_schema_default() {
        assert_eq!(policy("test", "User").mssql_schema(), "Security");
    }

    #[test]
    fn test_policy_with_mssql_schema() {
        let p = policy("test", "User").with_mssql_schema("RLS");
        assert_eq!(p.mssql_schema(), "RLS");
    }

    #[test]
    fn test_policy_mssql_predicate_function_name() {
        let p = policy("user_filter", "User");
        assert_eq!(p.mssql_predicate_function_name(), "fn_user_filter_predicate");
    }

    #[test]
    fn test_policy_with_mssql_block_operations() {
        let p = policy("test", "User").with_mssql_block_operations(vec![
            MssqlBlockOperation::AfterInsert,
            MssqlBlockOperation::AfterUpdate,
        ]);

        assert_eq!(p.mssql_block_operations.len(), 2);
    }

    #[test]
    fn test_policy_add_mssql_block_operation() {
        let mut p = policy("test", "User");
        for op in [MssqlBlockOperation::AfterInsert, MssqlBlockOperation::BeforeDelete] {
            p.add_mssql_block_operation(op);
        }

        assert_eq!(p.mssql_block_operations.len(), 2);
    }

    // -- Policy: SQL Server rendering -----------------------------------------

    #[test]
    fn test_policy_to_mssql_sql_simple() {
        let mssql = policy("user_filter", "User")
            .with_commands(vec![PolicyCommand::Select])
            .with_using("UserId = @UserId")
            .to_mssql_sql("dbo.Users", "UserId");

        assert_contains_all(&mssql.schema_sql, &["CREATE SCHEMA Security"]);
        assert_contains_all(
            &mssql.function_sql,
            &[
                "CREATE FUNCTION Security.fn_user_filter_predicate",
                "@UserId AS INT",
                "WITH SCHEMABINDING",
                "RETURNS TABLE",
            ],
        );
        assert_contains_all(
            &mssql.policy_sql,
            &[
                "CREATE SECURITY POLICY Security.user_filter",
                "FILTER PREDICATE",
                "ON dbo.Users",
                "WITH (STATE = ON)",
            ],
        );
    }

    #[test]
    fn test_policy_to_mssql_sql_with_check() {
        let mssql = policy("user_insert", "User")
            .with_commands(vec![PolicyCommand::Insert])
            .with_check("UserId = @UserId")
            .to_mssql_sql("dbo.Users", "UserId");

        assert_contains_all(&mssql.policy_sql, &["BLOCK PREDICATE", "AFTER INSERT"]);
    }

    #[test]
    fn test_policy_to_mssql_sql_with_both() {
        let mssql = policy("user_all", "User")
            .with_commands(vec![PolicyCommand::All])
            .with_using("UserId = @UserId")
            .with_check("UserId = @UserId")
            .to_mssql_sql("dbo.Users", "UserId");

        assert_contains_all(
            &mssql.policy_sql,
            &[
                "FILTER PREDICATE",
                "BLOCK PREDICATE",
                "AFTER INSERT",
                "AFTER UPDATE",
                "BEFORE UPDATE",
                "BEFORE DELETE",
            ],
        );
    }

    #[test]
    fn test_policy_to_mssql_sql_custom_schema() {
        let mssql = policy("test", "User")
            .with_mssql_schema("RLS")
            .with_using("UserId = @UserId")
            .to_mssql_sql("dbo.Users", "UserId");

        assert_contains_all(&mssql.schema_sql, &["CREATE SCHEMA RLS"]);
        assert_contains_all(&mssql.function_sql, &["RLS.fn_test_predicate"]);
        assert_contains_all(&mssql.policy_sql, &["RLS.test"]);
    }

    #[test]
    fn test_policy_to_mssql_sql_translates_postgres_functions() {
        let mssql = policy("test", "User")
            .with_using("id = current_user_id()")
            .to_mssql_sql("dbo.Users", "UserId");

        assert_contains_all(&mssql.function_sql, &["SESSION_CONTEXT(N'UserId')"]);
        assert_contains_none(&mssql.function_sql, &["current_user_id"]);
    }

    #[test]
    fn test_policy_to_mssql_sql_translates_auth_uid() {
        let mssql = policy("test", "User")
            .with_using("id = auth.uid()")
            .to_mssql_sql("dbo.Users", "UserId");

        assert_contains_all(&mssql.function_sql, &["SESSION_CONTEXT(N'UserId')"]);
        assert_contains_none(&mssql.function_sql, &["auth.uid"]);
    }

    #[test]
    fn test_mssql_policy_statements_all_statements() {
        let mssql = policy("test", "User")
            .with_using("UserId = @UserId")
            .to_mssql_sql("dbo.Users", "UserId");

        assert_eq!(mssql.all_statements().len(), 3);
    }

    #[test]
    fn test_mssql_policy_statements_to_sql() {
        let full_sql = policy("test", "User")
            .with_using("UserId = @UserId")
            .to_mssql_sql("dbo.Users", "UserId")
            .to_sql();

        assert_contains_all(
            &full_sql,
            &["GO", "CREATE SCHEMA", "CREATE FUNCTION", "CREATE SECURITY POLICY"],
        );
    }

    #[test]
    fn test_policy_default_mssql_block_operations() {
        // (command, fragments that must appear, fragments that must not appear)
        let cases: [(PolicyCommand, &[&str], &[&str]); 3] = [
            (PolicyCommand::Insert, &["AFTER INSERT"], &["BEFORE DELETE"]),
            (
                PolicyCommand::Update,
                &["AFTER UPDATE", "BEFORE UPDATE"],
                &["AFTER INSERT"],
            ),
            (PolicyCommand::Delete, &["BEFORE DELETE"], &["AFTER INSERT"]),
        ];

        for (command, present, absent) in cases {
            let mssql = policy("test", "User")
                .with_commands(vec![command])
                .with_check("true")
                .to_mssql_sql("dbo.Users", "UserId");

            assert_contains_all(&mssql.policy_sql, present);
            assert_contains_none(&mssql.policy_sql, absent);
        }
    }

    #[test]
    fn test_policy_mssql_custom_block_operations() {
        let mssql = policy("test", "User")
            .with_commands(vec![PolicyCommand::All])
            .with_check("true")
            .with_mssql_block_operations(vec![MssqlBlockOperation::AfterInsert])
            .to_mssql_sql("dbo.Users", "UserId");

        assert_contains_all(&mssql.policy_sql, &["AFTER INSERT"]);
        assert_contains_none(&mssql.policy_sql, &["BEFORE DELETE", "AFTER UPDATE"]);
    }

    // -- SQL Server RLS scenarios ---------------------------------------------

    #[test]
    fn test_mssql_rls_scenario_user_isolation() {
        let mssql = policy("user_isolation", "User")
            .with_mssql_schema("Security")
            .with_commands(vec![PolicyCommand::All])
            .with_using("UserId = CAST(SESSION_CONTEXT(N'UserId') AS INT)")
            .with_check("UserId = CAST(SESSION_CONTEXT(N'UserId') AS INT)")
            .to_mssql_sql("dbo.Users", "UserId");

        assert_contains_all(&mssql.schema_sql, &["CREATE SCHEMA Security"]);
        assert_contains_all(&mssql.function_sql, &["fn_user_isolation_predicate"]);
        assert_contains_all(&mssql.policy_sql, &["user_isolation", "WITH (STATE = ON)"]);
    }

    #[test]
    fn test_mssql_rls_scenario_multi_tenant() {
        let mssql = policy("tenant_isolation", "Order")
            .with_mssql_schema("MultiTenant")
            .with_using("TenantId = CAST(SESSION_CONTEXT(N'TenantId') AS INT)")
            .with_check("TenantId = CAST(SESSION_CONTEXT(N'TenantId') AS INT)")
            .to_mssql_sql("dbo.Orders", "TenantId");

        assert_contains_all(&mssql.schema_sql, &["MultiTenant"]);
        assert_contains_all(&mssql.function_sql, &["@TenantId AS INT"]);
        assert_contains_all(&mssql.policy_sql, &["dbo.Orders"]);
    }
}
1259