1use serde::{Deserialize, Serialize};
26use smol_str::SmolStr;
27
28use super::{Documentation, Ident, Span};
29
/// A row-level-security policy declared on a table (model).
///
/// The same definition can be rendered as PostgreSQL `CREATE POLICY` DDL
/// ([`Policy::to_postgres_sql`]) or as SQL Server security-policy DDL
/// ([`Policy::to_mssql_sql`]).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Policy {
    /// Policy name. Used verbatim as the PostgreSQL policy name and as the SQL
    /// Server security-policy name (and inside `fn_<name>_predicate`).
    pub name: Ident,
    /// The table/model the policy is declared on. The SQL generators take the
    /// physical table name as a separate argument instead of reading this field.
    pub table: Ident,
    /// Permissive or restrictive; [`Policy::new`] defaults to permissive.
    pub policy_type: PolicyType,
    /// Commands the policy covers. [`PolicyCommand::All`] covers every command;
    /// [`Policy::new`] starts with `[All]`.
    pub commands: Vec<PolicyCommand>,
    /// Roles the policy is granted to. An empty list means `PUBLIC`
    /// (see [`Policy::effective_roles`]).
    pub roles: Vec<SmolStr>,
    /// Raw SQL boolean expression rendered as `USING (...)`. It is inserted into
    /// the generated SQL as-is (no escaping).
    pub using_expr: Option<String>,
    /// Raw SQL boolean expression rendered as `WITH CHECK (...)`. Its presence
    /// also makes the SQL Server renderer emit block predicates.
    /// Inserted into the generated SQL as-is (no escaping).
    pub check_expr: Option<String>,
    /// SQL Server schema for the generated objects; `None` means `Security`.
    pub mssql_schema: Option<SmolStr>,
    /// Explicit SQL Server block-predicate operations. When empty, they are
    /// derived from `commands`.
    pub mssql_block_operations: Vec<MssqlBlockOperation>,
    /// Optional documentation attached to the policy.
    pub documentation: Option<Documentation>,
    /// Location of the declaration (presumably in the schema source — confirm).
    pub span: Span,
    /// Identifier of the source the policy was loaded from, if known. It is
    /// omitted from serialized output when `None` and defaults to `None` when
    /// missing on input.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub source_id: Option<crate::loader::SourceId>,
}
88
89impl Policy {
90 pub fn new(name: Ident, table: Ident, span: Span) -> Self {
92 Self {
93 name,
94 table,
95 policy_type: PolicyType::Permissive,
96 commands: vec![PolicyCommand::All],
97 roles: vec![],
98 using_expr: None,
99 check_expr: None,
100 mssql_schema: None,
101 mssql_block_operations: vec![],
102 documentation: None,
103 source_id: None,
104 span,
105 }
106 }
107
108 pub fn name(&self) -> &str {
110 self.name.as_str()
111 }
112
113 pub fn table(&self) -> &str {
115 self.table.as_str()
116 }
117
118 pub fn with_type(mut self, policy_type: PolicyType) -> Self {
120 self.policy_type = policy_type;
121 self
122 }
123
124 pub fn with_commands(mut self, commands: Vec<PolicyCommand>) -> Self {
126 self.commands = commands;
127 self
128 }
129
130 pub fn add_command(&mut self, command: PolicyCommand) {
132 self.commands.push(command);
133 }
134
135 pub fn with_roles(mut self, roles: Vec<SmolStr>) -> Self {
137 self.roles = roles;
138 self
139 }
140
141 pub fn add_role(&mut self, role: impl Into<SmolStr>) {
143 self.roles.push(role.into());
144 }
145
146 pub fn with_using(mut self, expr: impl Into<String>) -> Self {
148 self.using_expr = Some(expr.into());
149 self
150 }
151
152 pub fn with_check(mut self, expr: impl Into<String>) -> Self {
154 self.check_expr = Some(expr.into());
155 self
156 }
157
158 pub fn with_documentation(mut self, doc: Documentation) -> Self {
160 self.documentation = Some(doc);
161 self
162 }
163
164 pub fn with_mssql_schema(mut self, schema: impl Into<SmolStr>) -> Self {
166 self.mssql_schema = Some(schema.into());
167 self
168 }
169
170 pub fn with_mssql_block_operations(mut self, operations: Vec<MssqlBlockOperation>) -> Self {
172 self.mssql_block_operations = operations;
173 self
174 }
175
176 pub fn add_mssql_block_operation(&mut self, operation: MssqlBlockOperation) {
178 self.mssql_block_operations.push(operation);
179 }
180
181 pub fn applies_to(&self, command: PolicyCommand) -> bool {
183 self.commands.contains(&PolicyCommand::All) || self.commands.contains(&command)
184 }
185
186 pub fn is_restrictive(&self) -> bool {
188 self.policy_type == PolicyType::Restrictive
189 }
190
191 pub fn is_permissive(&self) -> bool {
193 self.policy_type == PolicyType::Permissive
194 }
195
196 pub fn effective_roles(&self) -> Vec<&str> {
198 if self.roles.is_empty() {
199 vec!["PUBLIC"]
200 } else {
201 self.roles.iter().map(|r| r.as_str()).collect()
202 }
203 }
204
205 pub fn mssql_schema(&self) -> &str {
207 self.mssql_schema
208 .as_ref()
209 .map(|s| s.as_str())
210 .unwrap_or("Security")
211 }
212
213 pub fn mssql_predicate_function_name(&self) -> String {
215 format!("fn_{}_predicate", self.name())
216 }
217
218 pub fn to_sql(&self, table_name: &str) -> String {
220 self.to_postgres_sql(table_name)
221 }
222
223 pub fn to_postgres_sql(&self, table_name: &str) -> String {
225 let mut sql = format!("CREATE POLICY {} ON {}", self.name(), table_name);
226
227 match self.policy_type {
229 PolicyType::Permissive => {} PolicyType::Restrictive => sql.push_str(" AS RESTRICTIVE"),
231 }
232
233 if !self.commands.is_empty() && !self.commands.contains(&PolicyCommand::All) {
235 let cmds: Vec<&str> = self.commands.iter().map(|c| c.as_str()).collect();
236 sql.push_str(&format!(" FOR {}", cmds[0]));
238 }
239
240 let roles = self.effective_roles();
242 sql.push_str(&format!(" TO {}", roles.join(", ")));
243
244 if let Some(ref using) = self.using_expr {
246 sql.push_str(&format!(" USING ({})", using));
247 }
248
249 if let Some(ref check) = self.check_expr {
251 sql.push_str(&format!(" WITH CHECK ({})", check));
252 }
253
254 sql
255 }
256
257 pub fn to_mssql_sql(&self, table_name: &str, predicate_column: &str) -> MssqlPolicyStatements {
274 let schema = self.mssql_schema();
275 let func_name = self.mssql_predicate_function_name();
276
277 let filter_expr = self
279 .using_expr
280 .as_deref()
281 .unwrap_or("1 = 1")
282 .replace(
283 "current_user_id()",
284 "CAST(SESSION_CONTEXT(N'UserId') AS INT)",
285 )
286 .replace("auth.uid()", "CAST(SESSION_CONTEXT(N'UserId') AS INT)")
287 .replace(
288 "current_setting('app.current_org')",
289 "SESSION_CONTEXT(N'OrgId')",
290 );
291
292 let function_sql = format!(
293 r#"CREATE FUNCTION {schema}.{func_name}(@{predicate_column} AS INT)
294 RETURNS TABLE
295WITH SCHEMABINDING
296AS
297 RETURN SELECT 1 AS fn_securitypredicate_result
298 WHERE {filter_expr}"#,
299 schema = schema,
300 func_name = func_name,
301 predicate_column = predicate_column,
302 filter_expr = filter_expr
303 );
304
305 let mut policy_sql = format!(
307 "CREATE SECURITY POLICY {schema}.{policy_name}\n",
308 schema = schema,
309 policy_name = self.name()
310 );
311
312 if self.using_expr.is_some() {
314 policy_sql.push_str(&format!(
315 "ADD FILTER PREDICATE {schema}.{func_name}({predicate_column}) ON {table_name}",
316 schema = schema,
317 func_name = func_name,
318 predicate_column = predicate_column,
319 table_name = table_name
320 ));
321 }
322
323 if self.check_expr.is_some() {
325 let block_ops = if self.mssql_block_operations.is_empty() {
326 self.default_mssql_block_operations()
328 } else {
329 self.mssql_block_operations.clone()
330 };
331
332 for (i, op) in block_ops.iter().enumerate() {
333 if i > 0 || self.using_expr.is_some() {
334 policy_sql.push_str(",\n");
335 }
336 policy_sql.push_str(&format!(
337 "ADD BLOCK PREDICATE {schema}.{func_name}({predicate_column}) ON {table_name} {op}",
338 schema = schema,
339 func_name = func_name,
340 predicate_column = predicate_column,
341 table_name = table_name,
342 op = op.as_str()
343 ));
344 }
345 }
346
347 policy_sql.push_str("\nWITH (STATE = ON)");
348
349 MssqlPolicyStatements {
350 schema_sql: format!("CREATE SCHEMA {schema}"),
351 function_sql,
352 policy_sql,
353 }
354 }
355
356 fn default_mssql_block_operations(&self) -> Vec<MssqlBlockOperation> {
358 let mut ops = vec![];
359
360 if self.applies_to(PolicyCommand::Insert) {
361 ops.push(MssqlBlockOperation::AfterInsert);
362 }
363 if self.applies_to(PolicyCommand::Update) {
364 ops.push(MssqlBlockOperation::AfterUpdate);
365 ops.push(MssqlBlockOperation::BeforeUpdate);
366 }
367 if self.applies_to(PolicyCommand::Delete) {
368 ops.push(MssqlBlockOperation::BeforeDelete);
369 }
370
371 ops
372 }
373}
374
/// How a policy combines with other policies on the same table.
///
/// In PostgreSQL, permissive policies are OR-ed together and restrictive ones
/// are AND-ed. Only `Restrictive` is rendered explicitly (`AS RESTRICTIVE`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
pub enum PolicyType {
    /// Default policy type; emits no `AS ...` clause.
    #[default]
    Permissive,
    /// Rendered as `AS RESTRICTIVE` in PostgreSQL DDL.
    Restrictive,
}
386
387impl PolicyType {
388 #[allow(clippy::should_implement_trait)]
390 pub fn from_str(s: &str) -> Option<Self> {
391 match s.to_uppercase().as_str() {
392 "PERMISSIVE" => Some(Self::Permissive),
393 "RESTRICTIVE" => Some(Self::Restrictive),
394 _ => None,
395 }
396 }
397
398 pub fn as_str(&self) -> &'static str {
400 match self {
401 Self::Permissive => "PERMISSIVE",
402 Self::Restrictive => "RESTRICTIVE",
403 }
404 }
405}
406
407impl std::fmt::Display for PolicyType {
408 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
409 write!(f, "{}", self.as_str())
410 }
411}
412
/// SQL command a policy applies to.
///
/// `All` covers every command (see `Policy::applies_to`). The other variants
/// render as the `FOR <command>` clause of PostgreSQL DDL.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum PolicyCommand {
    /// Every command; renders without a `FOR` clause.
    All,
    /// `SELECT` (reads).
    Select,
    /// `INSERT`.
    Insert,
    /// `UPDATE`.
    Update,
    /// `DELETE`.
    Delete,
}
427
428impl PolicyCommand {
429 #[allow(clippy::should_implement_trait)]
431 pub fn from_str(s: &str) -> Option<Self> {
432 match s.to_uppercase().as_str() {
433 "ALL" => Some(Self::All),
434 "SELECT" => Some(Self::Select),
435 "INSERT" => Some(Self::Insert),
436 "UPDATE" => Some(Self::Update),
437 "DELETE" => Some(Self::Delete),
438 _ => None,
439 }
440 }
441
442 pub fn as_str(&self) -> &'static str {
444 match self {
445 Self::All => "ALL",
446 Self::Select => "SELECT",
447 Self::Insert => "INSERT",
448 Self::Update => "UPDATE",
449 Self::Delete => "DELETE",
450 }
451 }
452
453 pub fn requires_using(&self) -> bool {
455 matches!(self, Self::All | Self::Select | Self::Update | Self::Delete)
456 }
457
458 pub fn requires_check(&self) -> bool {
460 matches!(self, Self::All | Self::Insert | Self::Update)
461 }
462}
463
464impl std::fmt::Display for PolicyCommand {
465 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
466 write!(f, "{}", self.as_str())
467 }
468}
469
/// SQL Server block-predicate operation.
///
/// Rendered after `ON <table>` in the `ADD BLOCK PREDICATE` clauses produced
/// by `Policy::to_mssql_sql`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum MssqlBlockOperation {
    /// `AFTER INSERT`; derived by default for INSERT-covering policies.
    AfterInsert,
    /// `AFTER UPDATE`; derived by default for UPDATE-covering policies.
    AfterUpdate,
    /// `BEFORE UPDATE`; derived by default for UPDATE-covering policies.
    BeforeUpdate,
    /// `BEFORE DELETE`; derived by default for DELETE-covering policies.
    BeforeDelete,
}
489
490impl MssqlBlockOperation {
491 #[allow(clippy::should_implement_trait)]
493 pub fn from_str(s: &str) -> Option<Self> {
494 match s.to_uppercase().replace([' ', '_'], "").as_str() {
495 "AFTERINSERT" => Some(Self::AfterInsert),
496 "AFTERUPDATE" => Some(Self::AfterUpdate),
497 "BEFOREUPDATE" => Some(Self::BeforeUpdate),
498 "BEFOREDELETE" => Some(Self::BeforeDelete),
499 _ => None,
500 }
501 }
502
503 pub fn as_str(&self) -> &'static str {
505 match self {
506 Self::AfterInsert => "AFTER INSERT",
507 Self::AfterUpdate => "AFTER UPDATE",
508 Self::BeforeUpdate => "BEFORE UPDATE",
509 Self::BeforeDelete => "BEFORE DELETE",
510 }
511 }
512}
513
514impl std::fmt::Display for MssqlBlockOperation {
515 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
516 write!(f, "{}", self.as_str())
517 }
518}
519
/// The SQL Server DDL produced by `Policy::to_mssql_sql`.
#[derive(Debug, Clone, PartialEq)]
pub struct MssqlPolicyStatements {
    /// `CREATE SCHEMA <schema>` for the schema hosting the generated objects.
    pub schema_sql: String,
    /// `CREATE FUNCTION <schema>.fn_<policy>_predicate(...)`: the inline
    /// table-valued predicate function.
    pub function_sql: String,
    /// `CREATE SECURITY POLICY ...` with its filter/block predicates and
    /// `WITH (STATE = ON)`.
    pub policy_sql: String,
}
530
531impl MssqlPolicyStatements {
532 pub fn all_statements(&self) -> Vec<&str> {
534 vec![&self.schema_sql, &self.function_sql, &self.policy_sql]
535 }
536
537 pub fn to_sql(&self) -> String {
539 format!(
540 "{schema_sql};\nGO\n\n{function_sql};\nGO\n\n{policy_sql};",
541 schema_sql = self.schema_sql,
542 function_sql = self.function_sql,
543 policy_sql = self.policy_sql
544 )
545 }
546}
547
#[cfg(test)]
mod tests {
    use super::PolicyCommand as Cmd;
    use super::*;

    /// Fixed span shared by every fixture; its value is irrelevant to the assertions.
    fn test_span() -> Span {
        Span::new(0, 10)
    }

    /// Builds an identifier carrying the fixture span.
    fn ident(text: &str) -> Ident {
        Ident::new(text, test_span())
    }

    /// `Policy::new` with fixture identifiers and span.
    fn new_policy(name: &str, table: &str) -> Policy {
        Policy::new(ident(name), ident(table), test_span())
    }

    /// Asserts that `sql` contains every fragment.
    fn assert_contains_all(sql: &str, fragments: &[&str]) {
        for fragment in fragments {
            assert!(sql.contains(*fragment), "expected {:?} in:\n{}", fragment, sql);
        }
    }

    /// Asserts that `sql` contains none of the fragments.
    fn assert_contains_none(sql: &str, fragments: &[&str]) {
        for fragment in fragments {
            assert!(!sql.contains(*fragment), "unexpected {:?} in:\n{}", fragment, sql);
        }
    }

    // ---------------------------------------------------------------------
    // Policy construction and builders
    // ---------------------------------------------------------------------

    #[test]
    fn test_policy_new() {
        let p = new_policy("read_own", "User");

        assert_eq!(p.name(), "read_own");
        assert_eq!(p.table(), "User");
        assert_eq!(p.policy_type, PolicyType::Permissive);
        assert_eq!(p.commands, vec![Cmd::All]);
        assert!(p.roles.is_empty());
        assert!(p.using_expr.is_none());
        assert!(p.check_expr.is_none());
        assert!(p.documentation.is_none());
    }

    #[test]
    fn test_policy_with_type() {
        let p = new_policy("strict", "User").with_type(PolicyType::Restrictive);

        assert!(p.is_restrictive());
        assert!(!p.is_permissive());
    }

    #[test]
    fn test_policy_with_commands() {
        let p = new_policy("read", "User").with_commands(vec![Cmd::Select]);

        assert!(p.applies_to(Cmd::Select));
        for cmd in [Cmd::Insert, Cmd::Update, Cmd::Delete] {
            assert!(!p.applies_to(cmd), "{} should not apply", cmd);
        }
    }

    #[test]
    fn test_policy_with_multiple_commands() {
        let p = new_policy("read_update", "User").with_commands(vec![Cmd::Select, Cmd::Update]);

        for cmd in [Cmd::Select, Cmd::Update] {
            assert!(p.applies_to(cmd), "{} should apply", cmd);
        }
        for cmd in [Cmd::Insert, Cmd::Delete] {
            assert!(!p.applies_to(cmd), "{} should not apply", cmd);
        }
    }

    #[test]
    fn test_policy_all_command_applies_to_all() {
        let p = new_policy("all", "User").with_commands(vec![Cmd::All]);

        for cmd in [Cmd::Select, Cmd::Insert, Cmd::Update, Cmd::Delete, Cmd::All] {
            assert!(p.applies_to(cmd), "{} should apply", cmd);
        }
    }

    #[test]
    fn test_policy_add_command() {
        let mut p = new_policy("test", "User").with_commands(vec![]);

        for cmd in [Cmd::Select, Cmd::Update] {
            p.add_command(cmd);
        }

        assert_eq!(p.commands.len(), 2);
        assert!(p.applies_to(Cmd::Select));
        assert!(p.applies_to(Cmd::Update));
    }

    #[test]
    fn test_policy_with_roles() {
        let p = new_policy("auth", "User").with_roles(vec!["authenticated".into(), "admin".into()]);

        assert_eq!(p.roles.len(), 2);
        let effective = p.effective_roles();
        for role in ["authenticated", "admin"] {
            assert!(effective.contains(&role), "missing role {}", role);
        }
    }

    #[test]
    fn test_policy_add_role() {
        let mut p = new_policy("test", "User");

        for role in ["user", "moderator"] {
            p.add_role(role);
        }

        assert_eq!(p.roles.len(), 2);
    }

    #[test]
    fn test_policy_effective_roles_default() {
        assert_eq!(new_policy("public", "User").effective_roles(), vec!["PUBLIC"]);
    }

    #[test]
    fn test_policy_with_using() {
        let p = new_policy("own", "User").with_using("user_id = current_user_id()");

        assert_eq!(p.using_expr.as_deref(), Some("user_id = current_user_id()"));
    }

    #[test]
    fn test_policy_with_check() {
        let p = new_policy("insert", "User").with_check("user_id = current_user_id()");

        assert_eq!(p.check_expr.as_deref(), Some("user_id = current_user_id()"));
    }

    #[test]
    fn test_policy_with_documentation() {
        let text = "Users can only see their own data";
        let p = new_policy("doc", "User").with_documentation(Documentation::new(text, test_span()));

        let doc = p.documentation.expect("documentation should be attached");
        assert_eq!(doc.text, text);
    }

    // ---------------------------------------------------------------------
    // PostgreSQL rendering
    // ---------------------------------------------------------------------

    #[test]
    fn test_policy_to_sql_simple() {
        let sql = new_policy("read_own", "User")
            .with_commands(vec![Cmd::Select])
            .with_using("id = current_user_id()")
            .to_sql("users");

        assert_contains_all(
            &sql,
            &[
                "CREATE POLICY read_own ON users",
                "FOR SELECT",
                "TO PUBLIC",
                "USING (id = current_user_id())",
            ],
        );
    }

    #[test]
    fn test_policy_to_sql_with_roles() {
        let sql = new_policy("auth_read", "User")
            .with_commands(vec![Cmd::Select])
            .with_roles(vec!["authenticated".into()])
            .with_using("true")
            .to_sql("users");

        assert_contains_all(&sql, &["TO authenticated"]);
    }

    #[test]
    fn test_policy_to_sql_restrictive() {
        let sql = new_policy("restrict", "User")
            .with_type(PolicyType::Restrictive)
            .with_using("org_id = current_org_id()")
            .to_sql("users");

        assert_contains_all(&sql, &["AS RESTRICTIVE"]);
    }

    #[test]
    fn test_policy_to_sql_with_check() {
        let sql = new_policy("insert_own", "User")
            .with_commands(vec![Cmd::Insert])
            .with_check("id = current_user_id()")
            .to_sql("users");

        assert_contains_all(&sql, &["FOR INSERT", "WITH CHECK (id = current_user_id())"]);
    }

    #[test]
    fn test_policy_to_sql_both_expressions() {
        let sql = new_policy("update_own", "User")
            .with_commands(vec![Cmd::Update])
            .with_using("id = current_user_id()")
            .with_check("id = current_user_id()")
            .to_sql("users");

        assert_contains_all(
            &sql,
            &[
                "USING (id = current_user_id())",
                "WITH CHECK (id = current_user_id())",
            ],
        );
    }

    #[test]
    fn test_policy_equality() {
        assert_eq!(new_policy("test", "User"), new_policy("test", "User"));
    }

    #[test]
    fn test_policy_clone() {
        let original = new_policy("original", "User").with_using("id = 1");

        let copy = original.clone();
        assert_eq!(copy.name(), "original");
        assert_eq!(copy.using_expr, Some("id = 1".to_string()));
    }

    // ---------------------------------------------------------------------
    // PolicyType
    // ---------------------------------------------------------------------

    #[test]
    fn test_policy_type_from_str() {
        let cases = [
            ("PERMISSIVE", Some(PolicyType::Permissive)),
            ("permissive", Some(PolicyType::Permissive)),
            ("Permissive", Some(PolicyType::Permissive)),
            ("RESTRICTIVE", Some(PolicyType::Restrictive)),
            ("restrictive", Some(PolicyType::Restrictive)),
            ("invalid", None),
        ];
        for (input, expected) in cases {
            assert_eq!(PolicyType::from_str(input), expected, "input {:?}", input);
        }
    }

    #[test]
    fn test_policy_type_as_str() {
        for (kind, text) in [
            (PolicyType::Permissive, "PERMISSIVE"),
            (PolicyType::Restrictive, "RESTRICTIVE"),
        ] {
            assert_eq!(kind.as_str(), text);
        }
    }

    #[test]
    fn test_policy_type_display() {
        for (kind, text) in [
            (PolicyType::Permissive, "PERMISSIVE"),
            (PolicyType::Restrictive, "RESTRICTIVE"),
        ] {
            assert_eq!(kind.to_string(), text);
        }
    }

    #[test]
    fn test_policy_type_default() {
        assert_eq!(PolicyType::default(), PolicyType::Permissive);
    }

    #[test]
    fn test_policy_type_equality() {
        assert_eq!(PolicyType::Permissive, PolicyType::Permissive);
        assert_eq!(PolicyType::Restrictive, PolicyType::Restrictive);
        assert_ne!(PolicyType::Permissive, PolicyType::Restrictive);
    }

    // ---------------------------------------------------------------------
    // PolicyCommand
    // ---------------------------------------------------------------------

    #[test]
    fn test_policy_command_from_str() {
        let cases = [
            ("ALL", Some(Cmd::All)),
            ("all", Some(Cmd::All)),
            ("SELECT", Some(Cmd::Select)),
            ("select", Some(Cmd::Select)),
            ("INSERT", Some(Cmd::Insert)),
            ("UPDATE", Some(Cmd::Update)),
            ("DELETE", Some(Cmd::Delete)),
            ("invalid", None),
        ];
        for (input, expected) in cases {
            assert_eq!(Cmd::from_str(input), expected, "input {:?}", input);
        }
    }

    /// Every command paired with its SQL keyword.
    const COMMAND_KEYWORDS: [(PolicyCommand, &str); 5] = [
        (PolicyCommand::All, "ALL"),
        (PolicyCommand::Select, "SELECT"),
        (PolicyCommand::Insert, "INSERT"),
        (PolicyCommand::Update, "UPDATE"),
        (PolicyCommand::Delete, "DELETE"),
    ];

    #[test]
    fn test_policy_command_as_str() {
        for (cmd, keyword) in COMMAND_KEYWORDS {
            assert_eq!(cmd.as_str(), keyword);
        }
    }

    #[test]
    fn test_policy_command_display() {
        for (cmd, keyword) in COMMAND_KEYWORDS {
            assert_eq!(cmd.to_string(), keyword);
        }
    }

    #[test]
    fn test_policy_command_requires_using() {
        for cmd in [Cmd::All, Cmd::Select, Cmd::Update, Cmd::Delete] {
            assert!(cmd.requires_using(), "{} should require USING", cmd);
        }
        assert!(!Cmd::Insert.requires_using());
    }

    #[test]
    fn test_policy_command_requires_check() {
        for cmd in [Cmd::All, Cmd::Insert, Cmd::Update] {
            assert!(cmd.requires_check(), "{} should require CHECK", cmd);
        }
        for cmd in [Cmd::Select, Cmd::Delete] {
            assert!(!cmd.requires_check(), "{} should not require CHECK", cmd);
        }
    }

    #[test]
    fn test_policy_command_equality() {
        assert_eq!(Cmd::Select, Cmd::Select);
        assert_ne!(Cmd::Select, Cmd::Insert);
    }

    // ---------------------------------------------------------------------
    // PostgreSQL scenarios
    // ---------------------------------------------------------------------

    #[test]
    fn test_policy_rls_scenario_user_isolation() {
        let p = new_policy("user_isolation", "User")
            .with_type(PolicyType::Permissive)
            .with_commands(vec![Cmd::All])
            .with_roles(vec!["authenticated".into()])
            .with_using("id = auth.uid()")
            .with_check("id = auth.uid()");

        assert!(p.is_permissive());
        for cmd in [Cmd::Select, Cmd::Insert, Cmd::Update, Cmd::Delete] {
            assert!(p.applies_to(cmd), "{} should apply", cmd);
        }

        assert_contains_all(&p.to_sql("users"), &["auth.uid()"]);
    }

    #[test]
    fn test_policy_rls_scenario_org_based() {
        let p = new_policy("org_access", "Document")
            .with_type(PolicyType::Restrictive)
            .with_commands(vec![Cmd::Select, Cmd::Update])
            .with_using("org_id = current_setting('app.current_org')::uuid");

        assert!(p.is_restrictive());
        assert!(p.applies_to(Cmd::Select));
        assert!(p.applies_to(Cmd::Update));
        assert!(!p.applies_to(Cmd::Delete));

        assert_contains_all(&p.to_sql("documents"), &["AS RESTRICTIVE", "current_setting"]);
    }

    #[test]
    fn test_policy_rls_scenario_public_read() {
        let reader = new_policy("public_read", "Post")
            .with_commands(vec![Cmd::Select])
            .with_using("published = true OR author_id = current_user_id()");
        let writer = new_policy("owner_write", "Post")
            .with_commands(vec![Cmd::Update, Cmd::Delete])
            .with_roles(vec!["authenticated".into()])
            .with_using("author_id = current_user_id()");

        assert_eq!(reader.effective_roles(), vec!["PUBLIC"]);
        assert!(writer.effective_roles().contains(&"authenticated"));
    }

    // ---------------------------------------------------------------------
    // MssqlBlockOperation
    // ---------------------------------------------------------------------

    #[test]
    fn test_mssql_block_operation_from_str() {
        let cases = [
            ("AFTER INSERT", Some(MssqlBlockOperation::AfterInsert)),
            ("after_insert", Some(MssqlBlockOperation::AfterInsert)),
            ("AFTERINSERT", Some(MssqlBlockOperation::AfterInsert)),
            ("AFTER UPDATE", Some(MssqlBlockOperation::AfterUpdate)),
            ("BEFORE UPDATE", Some(MssqlBlockOperation::BeforeUpdate)),
            ("BEFORE DELETE", Some(MssqlBlockOperation::BeforeDelete)),
            ("invalid", None),
        ];
        for (input, expected) in cases {
            assert_eq!(MssqlBlockOperation::from_str(input), expected, "input {:?}", input);
        }
    }

    #[test]
    fn test_mssql_block_operation_as_str() {
        for (op, text) in [
            (MssqlBlockOperation::AfterInsert, "AFTER INSERT"),
            (MssqlBlockOperation::AfterUpdate, "AFTER UPDATE"),
            (MssqlBlockOperation::BeforeUpdate, "BEFORE UPDATE"),
            (MssqlBlockOperation::BeforeDelete, "BEFORE DELETE"),
        ] {
            assert_eq!(op.as_str(), text);
        }
    }

    #[test]
    fn test_mssql_block_operation_display() {
        assert_eq!(MssqlBlockOperation::AfterInsert.to_string(), "AFTER INSERT");
        assert_eq!(MssqlBlockOperation::BeforeDelete.to_string(), "BEFORE DELETE");
    }

    // ---------------------------------------------------------------------
    // SQL Server configuration and rendering
    // ---------------------------------------------------------------------

    #[test]
    fn test_policy_mssql_schema_default() {
        assert_eq!(new_policy("test", "User").mssql_schema(), "Security");
    }

    #[test]
    fn test_policy_with_mssql_schema() {
        assert_eq!(new_policy("test", "User").with_mssql_schema("RLS").mssql_schema(), "RLS");
    }

    #[test]
    fn test_policy_mssql_predicate_function_name() {
        assert_eq!(
            new_policy("user_filter", "User").mssql_predicate_function_name(),
            "fn_user_filter_predicate"
        );
    }

    #[test]
    fn test_policy_with_mssql_block_operations() {
        let p = new_policy("test", "User").with_mssql_block_operations(vec![
            MssqlBlockOperation::AfterInsert,
            MssqlBlockOperation::AfterUpdate,
        ]);

        assert_eq!(p.mssql_block_operations.len(), 2);
    }

    #[test]
    fn test_policy_add_mssql_block_operation() {
        let mut p = new_policy("test", "User");

        for op in [MssqlBlockOperation::AfterInsert, MssqlBlockOperation::BeforeDelete] {
            p.add_mssql_block_operation(op);
        }

        assert_eq!(p.mssql_block_operations.len(), 2);
    }

    #[test]
    fn test_policy_to_mssql_sql_simple() {
        let out = new_policy("user_filter", "User")
            .with_commands(vec![Cmd::Select])
            .with_using("UserId = @UserId")
            .to_mssql_sql("dbo.Users", "UserId");

        // Schema statement.
        assert_contains_all(&out.schema_sql, &["CREATE SCHEMA Security"]);

        // Predicate function.
        assert_contains_all(
            &out.function_sql,
            &[
                "CREATE FUNCTION Security.fn_user_filter_predicate",
                "@UserId AS INT",
                "WITH SCHEMABINDING",
                "RETURNS TABLE",
            ],
        );

        // Security policy.
        assert_contains_all(
            &out.policy_sql,
            &[
                "CREATE SECURITY POLICY Security.user_filter",
                "FILTER PREDICATE",
                "ON dbo.Users",
                "WITH (STATE = ON)",
            ],
        );
    }

    #[test]
    fn test_policy_to_mssql_sql_with_check() {
        let out = new_policy("user_insert", "User")
            .with_commands(vec![Cmd::Insert])
            .with_check("UserId = @UserId")
            .to_mssql_sql("dbo.Users", "UserId");

        assert_contains_all(&out.policy_sql, &["BLOCK PREDICATE", "AFTER INSERT"]);
    }

    #[test]
    fn test_policy_to_mssql_sql_with_both() {
        let out = new_policy("user_all", "User")
            .with_commands(vec![Cmd::All])
            .with_using("UserId = @UserId")
            .with_check("UserId = @UserId")
            .to_mssql_sql("dbo.Users", "UserId");

        assert_contains_all(
            &out.policy_sql,
            &[
                "FILTER PREDICATE",
                "BLOCK PREDICATE",
                "AFTER INSERT",
                "AFTER UPDATE",
                "BEFORE UPDATE",
                "BEFORE DELETE",
            ],
        );
    }

    #[test]
    fn test_policy_to_mssql_sql_custom_schema() {
        let out = new_policy("test", "User")
            .with_mssql_schema("RLS")
            .with_using("UserId = @UserId")
            .to_mssql_sql("dbo.Users", "UserId");

        assert_contains_all(&out.schema_sql, &["CREATE SCHEMA RLS"]);
        assert_contains_all(&out.function_sql, &["RLS.fn_test_predicate"]);
        assert_contains_all(&out.policy_sql, &["RLS.test"]);
    }

    #[test]
    fn test_policy_to_mssql_sql_translates_postgres_functions() {
        let out = new_policy("test", "User")
            .with_using("id = current_user_id()")
            .to_mssql_sql("dbo.Users", "UserId");

        assert_contains_all(&out.function_sql, &["SESSION_CONTEXT(N'UserId')"]);
        assert_contains_none(&out.function_sql, &["current_user_id"]);
    }

    #[test]
    fn test_policy_to_mssql_sql_translates_auth_uid() {
        let out = new_policy("test", "User")
            .with_using("id = auth.uid()")
            .to_mssql_sql("dbo.Users", "UserId");

        assert_contains_all(&out.function_sql, &["SESSION_CONTEXT(N'UserId')"]);
        assert_contains_none(&out.function_sql, &["auth.uid"]);
    }

    #[test]
    fn test_mssql_policy_statements_all_statements() {
        let out = new_policy("test", "User")
            .with_using("UserId = @UserId")
            .to_mssql_sql("dbo.Users", "UserId");

        assert_eq!(out.all_statements().len(), 3);
    }

    #[test]
    fn test_mssql_policy_statements_to_sql() {
        let script = new_policy("test", "User")
            .with_using("UserId = @UserId")
            .to_mssql_sql("dbo.Users", "UserId")
            .to_sql();

        assert_contains_all(
            &script,
            &["GO", "CREATE SCHEMA", "CREATE FUNCTION", "CREATE SECURITY POLICY"],
        );
    }

    #[test]
    fn test_policy_default_mssql_block_operations() {
        /// Renders a check-only policy limited to `cmd` and returns its policy SQL.
        fn policy_sql_for(cmd: PolicyCommand) -> String {
            new_policy("test", "User")
                .with_commands(vec![cmd])
                .with_check("true")
                .to_mssql_sql("dbo.Users", "UserId")
                .policy_sql
        }

        let insert_sql = policy_sql_for(Cmd::Insert);
        assert_contains_all(&insert_sql, &["AFTER INSERT"]);
        assert_contains_none(&insert_sql, &["BEFORE DELETE"]);

        let update_sql = policy_sql_for(Cmd::Update);
        assert_contains_all(&update_sql, &["AFTER UPDATE", "BEFORE UPDATE"]);
        assert_contains_none(&update_sql, &["AFTER INSERT"]);

        let delete_sql = policy_sql_for(Cmd::Delete);
        assert_contains_all(&delete_sql, &["BEFORE DELETE"]);
        assert_contains_none(&delete_sql, &["AFTER INSERT"]);
    }

    #[test]
    fn test_policy_mssql_custom_block_operations() {
        let out = new_policy("test", "User")
            .with_commands(vec![Cmd::All])
            .with_check("true")
            .with_mssql_block_operations(vec![MssqlBlockOperation::AfterInsert])
            .to_mssql_sql("dbo.Users", "UserId");

        // Explicit operations replace the ones derived from `commands`.
        assert_contains_all(&out.policy_sql, &["AFTER INSERT"]);
        assert_contains_none(&out.policy_sql, &["BEFORE DELETE", "AFTER UPDATE"]);
    }

    // ---------------------------------------------------------------------
    // SQL Server scenarios
    // ---------------------------------------------------------------------

    #[test]
    fn test_mssql_rls_scenario_user_isolation() {
        let out = new_policy("user_isolation", "User")
            .with_mssql_schema("Security")
            .with_commands(vec![Cmd::All])
            .with_using("UserId = CAST(SESSION_CONTEXT(N'UserId') AS INT)")
            .with_check("UserId = CAST(SESSION_CONTEXT(N'UserId') AS INT)")
            .to_mssql_sql("dbo.Users", "UserId");

        assert_contains_all(&out.schema_sql, &["CREATE SCHEMA Security"]);
        assert_contains_all(&out.function_sql, &["fn_user_isolation_predicate"]);
        assert_contains_all(&out.policy_sql, &["user_isolation", "WITH (STATE = ON)"]);
    }

    #[test]
    fn test_mssql_rls_scenario_multi_tenant() {
        let out = new_policy("tenant_isolation", "Order")
            .with_mssql_schema("MultiTenant")
            .with_using("TenantId = CAST(SESSION_CONTEXT(N'TenantId') AS INT)")
            .with_check("TenantId = CAST(SESSION_CONTEXT(N'TenantId') AS INT)")
            .to_mssql_sql("dbo.Orders", "TenantId");

        assert_contains_all(&out.schema_sql, &["MultiTenant"]);
        assert_contains_all(&out.function_sql, &["@TenantId AS INT"]);
        assert_contains_all(&out.policy_sql, &["dbo.Orders"]);
    }
}