1use serde::{Deserialize, Serialize};
21
22use crate::error::{QueryError, QueryResult};
23use crate::sql::DatabaseType;
24
/// A row-level security (RLS) policy definition.
///
/// Construct one with [`RlsPolicy::new`], which returns an [`RlsPolicyBuilder`],
/// and render it with [`RlsPolicy::to_postgres_sql`], [`RlsPolicy::to_mssql_sql`]
/// or [`RlsPolicy::to_drop_sql`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct RlsPolicy {
    /// Policy name; interpolated verbatim (unquoted) into the generated SQL.
    pub name: String,
    /// Table the policy is attached to; interpolated verbatim (unquoted).
    pub table: String,
    /// Command the policy applies to (the `FOR ...` clause in PostgreSQL).
    pub command: PolicyCommand,
    /// Roles the policy applies to. In PostgreSQL output an empty list, or exactly
    /// `["PUBLIC"]`, omits the `TO` clause.
    pub roles: Vec<String>,
    /// Raw SQL expression for the `USING (...)` clause, if any.
    pub using: Option<String>,
    /// Raw SQL expression for the `WITH CHECK (...)` clause, if any.
    pub with_check: Option<String>,
    /// `true` renders `AS PERMISSIVE`, `false` renders `AS RESTRICTIVE`.
    pub permissive: bool,
}
47
/// The SQL command a row-level security policy applies to.
///
/// Rendered by [`PolicyCommand::to_sql`] into the `FOR <command>` clause of
/// `CREATE POLICY`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum PolicyCommand {
    /// Every command (`FOR ALL`); the default chosen by `RlsPolicyBuilder::new`.
    All,
    /// `FOR SELECT`.
    Select,
    /// `FOR INSERT`.
    Insert,
    /// `FOR UPDATE`.
    Update,
    /// `FOR DELETE`.
    Delete,
}
62
63impl PolicyCommand {
64 pub fn to_sql(&self) -> &'static str {
66 match self {
67 Self::All => "ALL",
68 Self::Select => "SELECT",
69 Self::Insert => "INSERT",
70 Self::Update => "UPDATE",
71 Self::Delete => "DELETE",
72 }
73 }
74}
75
76impl RlsPolicy {
77 pub fn new(name: impl Into<String>, table: impl Into<String>) -> RlsPolicyBuilder {
79 RlsPolicyBuilder::new(name, table)
80 }
81
82 pub fn to_postgres_sql(&self) -> String {
84 let mut sql = format!(
85 "CREATE POLICY {} ON {} AS {} FOR {}",
86 self.name,
87 self.table,
88 if self.permissive {
89 "PERMISSIVE"
90 } else {
91 "RESTRICTIVE"
92 },
93 self.command.to_sql()
94 );
95
96 if !self.roles.is_empty() && self.roles != vec!["PUBLIC"] {
97 sql.push_str(" TO ");
98 sql.push_str(&self.roles.join(", "));
99 }
100
101 if let Some(ref using) = self.using {
102 sql.push_str(" USING (");
103 sql.push_str(using);
104 sql.push(')');
105 }
106
107 if let Some(ref check) = self.with_check {
108 sql.push_str(" WITH CHECK (");
109 sql.push_str(check);
110 sql.push(')');
111 }
112
113 sql
114 }
115
116 pub fn to_mssql_sql(&self) -> Vec<String> {
118 let mut sqls = Vec::new();
119
120 let func_name = format!("fn_rls_{}", self.name);
122 if let Some(ref using) = self.using {
123 sqls.push(format!(
124 "CREATE FUNCTION dbo.{fn}(@tenant_id INT) \
125 RETURNS TABLE WITH SCHEMABINDING AS \
126 RETURN SELECT 1 AS result WHERE {expr}",
127 fn = func_name,
128 expr = using
129 ));
130 }
131
132 sqls.push(format!(
134 "CREATE SECURITY POLICY {name}_policy \
135 ADD FILTER PREDICATE dbo.{fn}(tenant_id) ON dbo.{table}, \
136 ADD BLOCK PREDICATE dbo.{fn}(tenant_id) ON dbo.{table} \
137 WITH (STATE = ON)",
138 name = self.name,
139 fn = func_name,
140 table = self.table
141 ));
142
143 sqls
144 }
145
146 pub fn to_drop_sql(&self, db_type: DatabaseType) -> String {
148 match db_type {
149 DatabaseType::PostgreSQL => {
150 format!("DROP POLICY IF EXISTS {} ON {}", self.name, self.table)
151 }
152 DatabaseType::MSSQL => format!("DROP SECURITY POLICY IF EXISTS {}_policy", self.name),
153 _ => String::new(),
154 }
155 }
156}
157
158#[derive(Debug, Clone)]
160pub struct RlsPolicyBuilder {
161 name: String,
162 table: String,
163 command: PolicyCommand,
164 roles: Vec<String>,
165 using: Option<String>,
166 with_check: Option<String>,
167 permissive: bool,
168}
169
170impl RlsPolicyBuilder {
171 pub fn new(name: impl Into<String>, table: impl Into<String>) -> Self {
173 Self {
174 name: name.into(),
175 table: table.into(),
176 command: PolicyCommand::All,
177 roles: vec!["PUBLIC".to_string()],
178 using: None,
179 with_check: None,
180 permissive: true,
181 }
182 }
183
184 pub fn for_command(mut self, cmd: PolicyCommand) -> Self {
186 self.command = cmd;
187 self
188 }
189
190 pub fn for_select(self) -> Self {
192 self.for_command(PolicyCommand::Select)
193 }
194
195 pub fn for_insert(self) -> Self {
197 self.for_command(PolicyCommand::Insert)
198 }
199
200 pub fn for_update(self) -> Self {
202 self.for_command(PolicyCommand::Update)
203 }
204
205 pub fn for_delete(self) -> Self {
207 self.for_command(PolicyCommand::Delete)
208 }
209
210 pub fn to_roles<I, S>(mut self, roles: I) -> Self
212 where
213 I: IntoIterator<Item = S>,
214 S: Into<String>,
215 {
216 self.roles = roles.into_iter().map(Into::into).collect();
217 self
218 }
219
220 pub fn using(mut self, expr: impl Into<String>) -> Self {
222 self.using = Some(expr.into());
223 self
224 }
225
226 pub fn with_check(mut self, expr: impl Into<String>) -> Self {
228 self.with_check = Some(expr.into());
229 self
230 }
231
232 pub fn restrictive(mut self) -> Self {
234 self.permissive = false;
235 self
236 }
237
238 pub fn permissive(mut self) -> Self {
240 self.permissive = true;
241 self
242 }
243
244 pub fn build(self) -> RlsPolicy {
246 RlsPolicy {
247 name: self.name,
248 table: self.table,
249 command: self.command,
250 roles: self.roles,
251 using: self.using,
252 with_check: self.with_check,
253 permissive: self.permissive,
254 }
255 }
256}
257
/// Tenant-isolation settings for a single table.
///
/// Records which column holds each row's tenant id and where the current tenant
/// id comes from. `to_postgres_rls` turns it into an [`RlsPolicy`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TenantPolicy {
    /// Table to isolate.
    pub table: String,
    /// Column holding each row's tenant id. The generated policy compares it
    /// against the current tenant id cast to `INT`.
    pub tenant_column: String,
    /// Where the current tenant id is read from.
    pub tenant_source: TenantSource,
}
268
/// Source of the current tenant id used by [`TenantPolicy`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum TenantSource {
    /// Name of a PostgreSQL setting, read with `current_setting('<name>')`.
    SessionVar(String),
    /// Session-context key. In PostgreSQL output it is also read with
    /// `current_setting('<key>')`.
    SessionContext(String),
    /// Name of a zero-argument SQL function, rendered as `<name>()`.
    Function(String),
}
279
280impl TenantPolicy {
281 pub fn new(
283 table: impl Into<String>,
284 tenant_column: impl Into<String>,
285 source: TenantSource,
286 ) -> Self {
287 Self {
288 table: table.into(),
289 tenant_column: tenant_column.into(),
290 tenant_source: source,
291 }
292 }
293
294 pub fn to_postgres_rls(&self) -> RlsPolicy {
296 let tenant_expr = match &self.tenant_source {
297 TenantSource::SessionVar(var) => format!("current_setting('{}')", var),
298 TenantSource::Function(func) => format!("{}()", func),
299 TenantSource::SessionContext(key) => format!("current_setting('{}')", key),
300 };
301
302 RlsPolicy::new(format!("{}_tenant_isolation", self.table), &self.table)
303 .using(format!("{} = {}::INT", self.tenant_column, tenant_expr))
304 .with_check(format!("{} = {}::INT", self.tenant_column, tenant_expr))
305 .build()
306 }
307
308 pub fn set_tenant_sql(&self, tenant_id: &str, db_type: DatabaseType) -> String {
310 match db_type {
311 DatabaseType::PostgreSQL => match &self.tenant_source {
312 TenantSource::SessionVar(var) => {
313 format!("SET LOCAL {} = '{}'", var, tenant_id)
314 }
315 _ => format!("SELECT set_config('app.tenant_id', '{}', true)", tenant_id),
316 },
317 DatabaseType::MSSQL => {
318 format!(
319 "EXEC sp_set_session_context @key = N'tenant_id', @value = {}",
320 tenant_id
321 )
322 }
323 _ => String::new(),
324 }
325 }
326}
327
/// A database role (or login/user) definition.
///
/// Build one with [`Role::new`]. Render it with `to_postgres_sql`, `to_mysql_sql`
/// or `to_mssql_sql`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Role {
    /// Role name; interpolated into the generated SQL.
    pub name: String,
    /// Whether the role can log in. This selects `LOGIN`/`NOLOGIN` in PostgreSQL,
    /// and chooses between creating a user/login or a role on MySQL and MSSQL.
    pub login: bool,
    /// Plain-text password, if any.
    ///
    /// NOTE(review): this value appears in the derived `Debug` and `Serialize`
    /// output; take care when logging or persisting a `Role`.
    pub password: Option<String>,
    /// Roles this role is made a member of (`IN ROLE` in PostgreSQL).
    pub inherit_from: Vec<String>,
    /// Emits `SUPERUSER`; only rendered by `to_postgres_sql`.
    pub superuser: bool,
    /// Emits `CREATEDB`; only rendered by `to_postgres_sql`.
    pub createdb: bool,
    /// Emits `CREATEROLE`; only rendered by `to_postgres_sql`.
    pub createrole: bool,
    /// `CONNECTION LIMIT <n>`; only rendered by `to_postgres_sql`.
    pub connection_limit: Option<i32>,
    /// Expiry timestamp for `VALID UNTIL '<value>'`; only rendered by `to_postgres_sql`.
    pub valid_until: Option<String>,
}
354
355impl Role {
356 pub fn new(name: impl Into<String>) -> RoleBuilder {
358 RoleBuilder::new(name)
359 }
360
361 pub fn to_postgres_sql(&self) -> String {
363 let mut sql = format!("CREATE ROLE {}", self.name);
364 let mut options = Vec::new();
365
366 if self.login {
367 options.push("LOGIN".to_string());
368 } else {
369 options.push("NOLOGIN".to_string());
370 }
371
372 if let Some(ref pwd) = self.password {
373 options.push(format!("PASSWORD '{}'", pwd));
374 }
375
376 if self.superuser {
377 options.push("SUPERUSER".to_string());
378 }
379
380 if self.createdb {
381 options.push("CREATEDB".to_string());
382 }
383
384 if self.createrole {
385 options.push("CREATEROLE".to_string());
386 }
387
388 if let Some(limit) = self.connection_limit {
389 options.push(format!("CONNECTION LIMIT {}", limit));
390 }
391
392 if let Some(ref until) = self.valid_until {
393 options.push(format!("VALID UNTIL '{}'", until));
394 }
395
396 if !self.inherit_from.is_empty() {
397 options.push(format!("IN ROLE {}", self.inherit_from.join(", ")));
398 }
399
400 if !options.is_empty() {
401 sql.push_str(" WITH ");
402 sql.push_str(&options.join(" "));
403 }
404
405 sql
406 }
407
408 pub fn to_mysql_sql(&self) -> Vec<String> {
410 let mut sqls = Vec::new();
411
412 if self.login {
413 let mut sql = format!("CREATE USER '{}'@'%'", self.name);
414 if let Some(ref pwd) = self.password {
415 sql.push_str(&format!(" IDENTIFIED BY '{}'", pwd));
416 }
417 sqls.push(sql);
418 } else {
419 sqls.push(format!("CREATE ROLE '{}'", self.name));
420 }
421
422 for parent in &self.inherit_from {
423 sqls.push(format!("GRANT '{}' TO '{}'", parent, self.name));
424 }
425
426 sqls
427 }
428
429 pub fn to_mssql_sql(&self, database: &str) -> Vec<String> {
431 let mut sqls = Vec::new();
432
433 if self.login {
434 let mut sql = format!("CREATE LOGIN {} WITH PASSWORD = ", self.name);
435 if let Some(ref pwd) = self.password {
436 sql.push_str(&format!("'{}'", pwd));
437 } else {
438 sql.push_str("''");
439 }
440 sqls.push(sql);
441 sqls.push(format!(
442 "USE {}; CREATE USER {} FOR LOGIN {}",
443 database, self.name, self.name
444 ));
445 } else {
446 sqls.push(format!("USE {}; CREATE ROLE {}", database, self.name));
447 }
448
449 for parent in &self.inherit_from {
450 sqls.push(format!("ALTER ROLE {} ADD MEMBER {}", parent, self.name));
451 }
452
453 sqls
454 }
455}
456
457#[derive(Debug, Clone)]
459pub struct RoleBuilder {
460 name: String,
461 login: bool,
462 password: Option<String>,
463 inherit_from: Vec<String>,
464 superuser: bool,
465 createdb: bool,
466 createrole: bool,
467 connection_limit: Option<i32>,
468 valid_until: Option<String>,
469}
470
471impl RoleBuilder {
472 pub fn new(name: impl Into<String>) -> Self {
474 Self {
475 name: name.into(),
476 login: false,
477 password: None,
478 inherit_from: Vec::new(),
479 superuser: false,
480 createdb: false,
481 createrole: false,
482 connection_limit: None,
483 valid_until: None,
484 }
485 }
486
487 pub fn login(mut self) -> Self {
489 self.login = true;
490 self
491 }
492
493 pub fn password(mut self, pwd: impl Into<String>) -> Self {
495 self.password = Some(pwd.into());
496 self.login = true;
497 self
498 }
499
500 pub fn inherit<I, S>(mut self, roles: I) -> Self
502 where
503 I: IntoIterator<Item = S>,
504 S: Into<String>,
505 {
506 self.inherit_from = roles.into_iter().map(Into::into).collect();
507 self
508 }
509
510 pub fn superuser(mut self) -> Self {
512 self.superuser = true;
513 self
514 }
515
516 pub fn createdb(mut self) -> Self {
518 self.createdb = true;
519 self
520 }
521
522 pub fn createrole(mut self) -> Self {
524 self.createrole = true;
525 self
526 }
527
528 pub fn connection_limit(mut self, limit: i32) -> Self {
530 self.connection_limit = Some(limit);
531 self
532 }
533
534 pub fn valid_until(mut self, timestamp: impl Into<String>) -> Self {
536 self.valid_until = Some(timestamp.into());
537 self
538 }
539
540 pub fn build(self) -> Role {
542 Role {
543 name: self.name,
544 login: self.login,
545 password: self.password,
546 inherit_from: self.inherit_from,
547 superuser: self.superuser,
548 createdb: self.createdb,
549 createrole: self.createrole,
550 connection_limit: self.connection_limit,
551 valid_until: self.valid_until,
552 }
553 }
554}
555
/// A `GRANT` of one or more privileges on a database object to a grantee.
///
/// Build one with [`Grant::new`]. `GrantBuilder::build` rejects grants that have
/// no object or no privileges.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Grant {
    /// Privileges to grant, rendered in this order.
    pub privileges: Vec<Privilege>,
    /// Object the privileges apply to.
    pub object: GrantObject,
    /// Role or user receiving the privileges; interpolated into the generated SQL.
    pub grantee: String,
    /// Whether to request `WITH GRANT OPTION`.
    pub with_grant_option: bool,
}
572
/// A privilege that can appear in a [`Grant`]; see [`Privilege::to_sql`] for the keywords.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum Privilege {
    /// `SELECT`.
    Select,
    /// `INSERT`.
    Insert,
    /// `UPDATE`.
    Update,
    /// `DELETE`.
    Delete,
    /// `TRUNCATE`.
    Truncate,
    /// `REFERENCES`.
    References,
    /// `TRIGGER`.
    Trigger,
    /// `ALL PRIVILEGES`.
    All,
    /// `EXECUTE`.
    Execute,
    /// `USAGE`.
    Usage,
    /// `CREATE`.
    Create,
    /// `CONNECT`.
    Connect,
}
601
602impl Privilege {
603 pub fn to_sql(&self) -> &'static str {
605 match self {
606 Self::Select => "SELECT",
607 Self::Insert => "INSERT",
608 Self::Update => "UPDATE",
609 Self::Delete => "DELETE",
610 Self::Truncate => "TRUNCATE",
611 Self::References => "REFERENCES",
612 Self::Trigger => "TRIGGER",
613 Self::All => "ALL PRIVILEGES",
614 Self::Execute => "EXECUTE",
615 Self::Usage => "USAGE",
616 Self::Create => "CREATE",
617 Self::Connect => "CONNECT",
618 }
619 }
620}
621
/// Target object of a [`Grant`]; [`GrantObject::to_sql`] gives its generic SQL form.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum GrantObject {
    /// A table, optionally restricted to specific columns.
    Table {
        /// Table name.
        name: String,
        /// Column list for a column-level grant; `None` means the whole table.
        columns: Option<Vec<String>>,
    },
    /// A schema (`SCHEMA <name>`).
    Schema(String),
    /// A database (`DATABASE <name>`).
    Database(String),
    /// A sequence (`SEQUENCE <name>`).
    Sequence(String),
    /// A function (`FUNCTION <name>(<args>)`). `args` is raw text placed inside the
    /// parentheses, e.g. an argument-type list.
    Function { name: String, args: String },
    /// Every table in a schema (`ALL TABLES IN SCHEMA <schema>`).
    AllTablesInSchema(String),
    /// Every sequence in a schema (`ALL SEQUENCES IN SCHEMA <schema>`).
    AllSequencesInSchema(String),
}
643
644impl GrantObject {
645 pub fn table(name: impl Into<String>) -> Self {
647 Self::Table {
648 name: name.into(),
649 columns: None,
650 }
651 }
652
653 pub fn table_columns<I, S>(name: impl Into<String>, columns: I) -> Self
655 where
656 I: IntoIterator<Item = S>,
657 S: Into<String>,
658 {
659 Self::Table {
660 name: name.into(),
661 columns: Some(columns.into_iter().map(Into::into).collect()),
662 }
663 }
664
665 pub fn schema(name: impl Into<String>) -> Self {
667 Self::Schema(name.into())
668 }
669
670 pub fn to_sql(&self) -> String {
672 match self {
673 Self::Table { name, columns } => {
674 if let Some(cols) = columns {
675 format!("TABLE {} ({})", name, cols.join(", "))
676 } else {
677 format!("TABLE {}", name)
678 }
679 }
680 Self::Schema(name) => format!("SCHEMA {}", name),
681 Self::Database(name) => format!("DATABASE {}", name),
682 Self::Sequence(name) => format!("SEQUENCE {}", name),
683 Self::Function { name, args } => format!("FUNCTION {}({})", name, args),
684 Self::AllTablesInSchema(schema) => format!("ALL TABLES IN SCHEMA {}", schema),
685 Self::AllSequencesInSchema(schema) => format!("ALL SEQUENCES IN SCHEMA {}", schema),
686 }
687 }
688}
689
690impl Grant {
691 pub fn new(grantee: impl Into<String>) -> GrantBuilder {
693 GrantBuilder::new(grantee)
694 }
695
696 pub fn to_postgres_sql(&self) -> String {
698 let privs: Vec<&str> = self.privileges.iter().map(Privilege::to_sql).collect();
699 let priv_sql = match &self.object {
700 GrantObject::Table {
701 columns: Some(cols),
702 ..
703 } => {
704 privs
706 .iter()
707 .map(|p| format!("{} ({})", p, cols.join(", ")))
708 .collect::<Vec<_>>()
709 .join(", ")
710 }
711 _ => privs.join(", "),
712 };
713
714 let obj_sql = match &self.object {
715 GrantObject::Table {
716 name,
717 columns: Some(_),
718 } => format!("TABLE {}", name),
719 _ => self.object.to_sql(),
720 };
721
722 let mut sql = format!("GRANT {} ON {} TO {}", priv_sql, obj_sql, self.grantee);
723
724 if self.with_grant_option {
725 sql.push_str(" WITH GRANT OPTION");
726 }
727
728 sql
729 }
730
731 pub fn to_mysql_sql(&self) -> String {
733 let privs: Vec<&str> = self.privileges.iter().map(Privilege::to_sql).collect();
734 let priv_sql = match &self.object {
735 GrantObject::Table {
736 columns: Some(cols),
737 ..
738 } => privs
739 .iter()
740 .map(|p| format!("{} ({})", p, cols.join(", ")))
741 .collect::<Vec<_>>()
742 .join(", "),
743 _ => privs.join(", "),
744 };
745
746 let obj = match &self.object {
747 GrantObject::Table { name, .. } => name.clone(),
748 GrantObject::Database(name) => format!("{}.*", name),
749 GrantObject::Schema(name) => format!("{}.*", name),
750 _ => "*.*".to_string(),
751 };
752
753 let mut sql = format!("GRANT {} ON {} TO '{}'@'%'", priv_sql, obj, self.grantee);
754
755 if self.with_grant_option {
756 sql.push_str(" WITH GRANT OPTION");
757 }
758
759 sql
760 }
761
762 pub fn to_mssql_sql(&self) -> String {
764 let privs: Vec<&str> = self.privileges.iter().map(Privilege::to_sql).collect();
765
766 let (obj_type, obj_name) = match &self.object {
767 GrantObject::Table { name, columns } => {
768 if let Some(cols) = columns {
769 return format!(
770 "GRANT {} ({}) ON {} TO {}",
771 privs.join(", "),
772 cols.join(", "),
773 name,
774 self.grantee
775 );
776 }
777 ("OBJECT", name.clone())
778 }
779 GrantObject::Schema(name) => ("SCHEMA", name.clone()),
780 GrantObject::Database(name) => ("DATABASE", name.clone()),
781 _ => ("OBJECT", "".to_string()),
782 };
783
784 format!(
785 "GRANT {} ON {}::{} TO {}",
786 privs.join(", "),
787 obj_type,
788 obj_name,
789 self.grantee
790 )
791 }
792}
793
794#[derive(Debug, Clone)]
796pub struct GrantBuilder {
797 grantee: String,
798 privileges: Vec<Privilege>,
799 object: Option<GrantObject>,
800 with_grant_option: bool,
801}
802
803impl GrantBuilder {
804 pub fn new(grantee: impl Into<String>) -> Self {
806 Self {
807 grantee: grantee.into(),
808 privileges: Vec::new(),
809 object: None,
810 with_grant_option: false,
811 }
812 }
813
814 pub fn privilege(mut self, priv_: Privilege) -> Self {
816 self.privileges.push(priv_);
817 self
818 }
819
820 pub fn select(self) -> Self {
822 self.privilege(Privilege::Select)
823 }
824
825 pub fn insert(self) -> Self {
827 self.privilege(Privilege::Insert)
828 }
829
830 pub fn update(self) -> Self {
832 self.privilege(Privilege::Update)
833 }
834
835 pub fn delete(self) -> Self {
837 self.privilege(Privilege::Delete)
838 }
839
840 pub fn all(self) -> Self {
842 self.privilege(Privilege::All)
843 }
844
845 pub fn on(mut self, object: GrantObject) -> Self {
847 self.object = Some(object);
848 self
849 }
850
851 pub fn on_table(self, table: impl Into<String>) -> Self {
853 self.on(GrantObject::table(table))
854 }
855
856 pub fn on_columns<I, S>(self, table: impl Into<String>, columns: I) -> Self
858 where
859 I: IntoIterator<Item = S>,
860 S: Into<String>,
861 {
862 self.on(GrantObject::table_columns(table, columns))
863 }
864
865 pub fn on_schema(self, schema: impl Into<String>) -> Self {
867 self.on(GrantObject::Schema(schema.into()))
868 }
869
870 pub fn with_grant_option(mut self) -> Self {
872 self.with_grant_option = true;
873 self
874 }
875
876 pub fn build(self) -> QueryResult<Grant> {
878 let object = self.object.ok_or_else(|| {
879 QueryError::invalid_input(
880 "object",
881 "Grant requires an object (use on_table, on_schema, etc.)",
882 )
883 })?;
884
885 if self.privileges.is_empty() {
886 return Err(QueryError::invalid_input(
887 "privileges",
888 "Grant requires at least one privilege",
889 ));
890 }
891
892 Ok(Grant {
893 privileges: self.privileges,
894 object,
895 grantee: self.grantee,
896 with_grant_option: self.with_grant_option,
897 })
898 }
899}
900
/// A data-masking rule for one column.
///
/// It renders either as a PostgreSQL view (`to_postgres_view`) or as a SQL Server
/// `ALTER COLUMN ... ADD MASKED` statement (`to_mssql_alter`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct DataMask {
    /// Table containing the column; interpolated verbatim.
    pub table: String,
    /// Column to mask; interpolated verbatim.
    pub column: String,
    /// Masking strategy.
    pub mask_function: MaskFunction,
}
915
/// Masking strategy for a [`DataMask`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum MaskFunction {
    /// PostgreSQL: `'****'` unless `current_user = 'admin'`. SQL Server: `default()`.
    Default,
    /// PostgreSQL: first character, then `***@`, then the domain, unless
    /// `current_user = 'admin'`. SQL Server: `email()`.
    Email,
    /// Keeps the first `prefix` and last `suffix` characters with `padding` between them.
    Partial {
        /// Number of leading characters kept.
        prefix: usize,
        /// Text placed between the kept prefix and suffix.
        padding: String,
        /// Number of trailing characters kept.
        suffix: usize,
    },
    /// PostgreSQL: `md5(random()::text)`. SQL Server: `random(1, 100)`.
    Random,
    /// PostgreSQL: `<func>(<column>)`. SQL Server: the string is used verbatim as
    /// the masking function text.
    Custom(String),
    /// PostgreSQL: `NULL`. SQL Server: `default()`.
    Null,
}
936
937impl DataMask {
938 pub fn new(table: impl Into<String>, column: impl Into<String>, mask: MaskFunction) -> Self {
940 Self {
941 table: table.into(),
942 column: column.into(),
943 mask_function: mask,
944 }
945 }
946
947 pub fn to_postgres_view(&self, view_name: &str) -> String {
949 let masked_expr = match &self.mask_function {
950 MaskFunction::Default => format!(
951 "CASE WHEN current_user = 'admin' THEN {} ELSE '****' END",
952 self.column
953 ),
954 MaskFunction::Email => format!(
955 "CASE WHEN current_user = 'admin' THEN {} ELSE \
956 CONCAT(LEFT({}, 1), '***@', SPLIT_PART({}, '@', 2)) END",
957 self.column, self.column, self.column
958 ),
959 MaskFunction::Partial {
960 prefix,
961 padding,
962 suffix,
963 } => format!(
964 "CONCAT(LEFT({}, {}), '{}', RIGHT({}, {}))",
965 self.column, prefix, padding, self.column, suffix
966 ),
967 MaskFunction::Null => "NULL".to_string(),
968 MaskFunction::Custom(func) => format!("{}({})", func, self.column),
969 MaskFunction::Random => format!("md5(random()::text)"),
970 };
971
972 format!(
973 "CREATE OR REPLACE VIEW {} AS SELECT *, {} AS {}_masked FROM {}",
974 view_name, masked_expr, self.column, self.table
975 )
976 }
977
978 pub fn to_mssql_alter(&self) -> String {
980 let mask_func = match &self.mask_function {
981 MaskFunction::Default => "default()".to_string(),
982 MaskFunction::Email => "email()".to_string(),
983 MaskFunction::Partial {
984 prefix,
985 padding,
986 suffix,
987 } => {
988 format!("partial({}, '{}', {})", prefix, padding, suffix)
989 }
990 MaskFunction::Random => "random(1, 100)".to_string(),
991 MaskFunction::Custom(func) => func.clone(),
992 MaskFunction::Null => "default()".to_string(),
993 };
994
995 format!(
996 "ALTER TABLE {} ALTER COLUMN {} ADD MASKED WITH (FUNCTION = '{}')",
997 self.table, self.column, mask_func
998 )
999 }
1000}
1001
/// Per-connection session settings applied right after connecting.
///
/// Build one with [`ConnectionProfile::new`]. Turn it into setup statements with
/// `to_postgres_setup` or `to_mysql_setup`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ConnectionProfile {
    /// Profile name; not used by the setup-statement renderers.
    pub name: String,
    /// Role switched to with `SET ROLE` (PostgreSQL only).
    pub role: String,
    /// Schemas for `SET search_path TO ...` (PostgreSQL only); skipped when empty.
    pub search_path: Vec<String>,
    /// `(key, value)` pairs rendered as `SET key = '...'` (PostgreSQL) or
    /// `SET @key = '...'` (MySQL), in insertion order.
    pub session_vars: Vec<(String, String)>,
    /// Makes the session default to read-only transactions.
    pub read_only: bool,
    /// `SET statement_timeout = <n>` (PostgreSQL only). Milliseconds, per the
    /// builder's `ms` parameter name.
    pub statement_timeout: Option<u32>,
    /// `SET lock_timeout = <n>` (PostgreSQL only). Milliseconds, per the builder's
    /// `ms` parameter name.
    pub lock_timeout: Option<u32>,
}
1024
1025impl ConnectionProfile {
1026 pub fn new(name: impl Into<String>, role: impl Into<String>) -> ConnectionProfileBuilder {
1028 ConnectionProfileBuilder::new(name, role)
1029 }
1030
1031 pub fn to_postgres_setup(&self) -> Vec<String> {
1033 let mut sqls = Vec::new();
1034
1035 sqls.push(format!("SET ROLE {}", self.role));
1036
1037 if !self.search_path.is_empty() {
1038 sqls.push(format!(
1039 "SET search_path TO {}",
1040 self.search_path.join(", ")
1041 ));
1042 }
1043
1044 if self.read_only {
1045 sqls.push("SET default_transaction_read_only = ON".to_string());
1046 }
1047
1048 if let Some(timeout) = self.statement_timeout {
1049 sqls.push(format!("SET statement_timeout = {}", timeout));
1050 }
1051
1052 if let Some(timeout) = self.lock_timeout {
1053 sqls.push(format!("SET lock_timeout = {}", timeout));
1054 }
1055
1056 for (key, value) in &self.session_vars {
1057 sqls.push(format!("SET {} = '{}'", key, value));
1058 }
1059
1060 sqls
1061 }
1062
1063 pub fn to_mysql_setup(&self) -> Vec<String> {
1065 let mut sqls = Vec::new();
1066
1067 if self.read_only {
1069 sqls.push("SET SESSION TRANSACTION READ ONLY".to_string());
1070 }
1071
1072 for (key, value) in &self.session_vars {
1073 sqls.push(format!("SET @{} = '{}'", key, value));
1074 }
1075
1076 sqls
1077 }
1078}
1079
1080#[derive(Debug, Clone)]
1082pub struct ConnectionProfileBuilder {
1083 name: String,
1084 role: String,
1085 search_path: Vec<String>,
1086 session_vars: Vec<(String, String)>,
1087 read_only: bool,
1088 statement_timeout: Option<u32>,
1089 lock_timeout: Option<u32>,
1090}
1091
1092impl ConnectionProfileBuilder {
1093 pub fn new(name: impl Into<String>, role: impl Into<String>) -> Self {
1095 Self {
1096 name: name.into(),
1097 role: role.into(),
1098 search_path: Vec::new(),
1099 session_vars: Vec::new(),
1100 read_only: false,
1101 statement_timeout: None,
1102 lock_timeout: None,
1103 }
1104 }
1105
1106 pub fn search_path<I, S>(mut self, schemas: I) -> Self
1108 where
1109 I: IntoIterator<Item = S>,
1110 S: Into<String>,
1111 {
1112 self.search_path = schemas.into_iter().map(Into::into).collect();
1113 self
1114 }
1115
1116 pub fn session_var(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
1118 self.session_vars.push((key.into(), value.into()));
1119 self
1120 }
1121
1122 pub fn read_only(mut self) -> Self {
1124 self.read_only = true;
1125 self
1126 }
1127
1128 pub fn statement_timeout(mut self, ms: u32) -> Self {
1130 self.statement_timeout = Some(ms);
1131 self
1132 }
1133
1134 pub fn lock_timeout(mut self, ms: u32) -> Self {
1136 self.lock_timeout = Some(ms);
1137 self
1138 }
1139
1140 pub fn build(self) -> ConnectionProfile {
1142 ConnectionProfile {
1143 name: self.name,
1144 role: self.role,
1145 search_path: self.search_path,
1146 session_vars: self.session_vars,
1147 read_only: self.read_only,
1148 statement_timeout: self.statement_timeout,
1149 lock_timeout: self.lock_timeout,
1150 }
1151 }
1152}
1153
1154pub mod mongodb {
1160 use serde::{Deserialize, Serialize};
1161 use serde_json::Value as JsonValue;
1162
1163 #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
1165 pub struct MongoRole {
1166 pub role: String,
1168 pub db: String,
1170 pub privileges: Vec<MongoPrivilege>,
1172 pub roles: Vec<InheritedRole>,
1174 }
1175
1176 #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
1178 pub struct MongoPrivilege {
1179 pub resource: MongoResource,
1181 pub actions: Vec<String>,
1183 }
1184
1185 #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
1187 #[serde(untagged)]
1188 pub enum MongoResource {
1189 Collection { db: String, collection: String },
1191 Database { db: String },
1193 Cluster { cluster: bool },
1195 }
1196
1197 #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
1199 pub struct InheritedRole {
1200 pub role: String,
1202 pub db: String,
1204 }
1205
1206 impl MongoRole {
1207 pub fn new(role: impl Into<String>, db: impl Into<String>) -> MongoRoleBuilder {
1209 MongoRoleBuilder::new(role, db)
1210 }
1211
1212 pub fn to_create_command(&self) -> JsonValue {
1214 let privileges: Vec<JsonValue> = self
1215 .privileges
1216 .iter()
1217 .map(|p| {
1218 let resource = match &p.resource {
1219 MongoResource::Collection { db, collection } => {
1220 serde_json::json!({ "db": db, "collection": collection })
1221 }
1222 MongoResource::Database { db } => {
1223 serde_json::json!({ "db": db, "collection": "" })
1224 }
1225 MongoResource::Cluster { .. } => {
1226 serde_json::json!({ "cluster": true })
1227 }
1228 };
1229 serde_json::json!({
1230 "resource": resource,
1231 "actions": p.actions
1232 })
1233 })
1234 .collect();
1235
1236 let roles: Vec<JsonValue> = self
1237 .roles
1238 .iter()
1239 .map(|r| serde_json::json!({ "role": r.role, "db": r.db }))
1240 .collect();
1241
1242 serde_json::json!({
1243 "createRole": self.role,
1244 "privileges": privileges,
1245 "roles": roles
1246 })
1247 }
1248 }
1249
1250 #[derive(Debug, Clone, Default)]
1252 pub struct MongoRoleBuilder {
1253 role: String,
1254 db: String,
1255 privileges: Vec<MongoPrivilege>,
1256 roles: Vec<InheritedRole>,
1257 }
1258
1259 impl MongoRoleBuilder {
1260 pub fn new(role: impl Into<String>, db: impl Into<String>) -> Self {
1262 Self {
1263 role: role.into(),
1264 db: db.into(),
1265 privileges: Vec::new(),
1266 roles: Vec::new(),
1267 }
1268 }
1269
1270 pub fn privilege_collection<I, S>(
1272 mut self,
1273 collection: impl Into<String>,
1274 actions: I,
1275 ) -> Self
1276 where
1277 I: IntoIterator<Item = S>,
1278 S: Into<String>,
1279 {
1280 self.privileges.push(MongoPrivilege {
1281 resource: MongoResource::Collection {
1282 db: self.db.clone(),
1283 collection: collection.into(),
1284 },
1285 actions: actions.into_iter().map(Into::into).collect(),
1286 });
1287 self
1288 }
1289
1290 pub fn privilege_database<I, S>(mut self, actions: I) -> Self
1292 where
1293 I: IntoIterator<Item = S>,
1294 S: Into<String>,
1295 {
1296 self.privileges.push(MongoPrivilege {
1297 resource: MongoResource::Database {
1298 db: self.db.clone(),
1299 },
1300 actions: actions.into_iter().map(Into::into).collect(),
1301 });
1302 self
1303 }
1304
1305 pub fn inherit(mut self, role: impl Into<String>, db: impl Into<String>) -> Self {
1307 self.roles.push(InheritedRole {
1308 role: role.into(),
1309 db: db.into(),
1310 });
1311 self
1312 }
1313
1314 pub fn build(self) -> MongoRole {
1316 MongoRole {
1317 role: self.role,
1318 db: self.db,
1319 privileges: self.privileges,
1320 roles: self.roles,
1321 }
1322 }
1323 }
1324
1325 #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
1327 pub struct FieldEncryption {
1328 pub key_vault_namespace: String,
1330 pub kms_providers: KmsProviders,
1332 pub schema_map: serde_json::Map<String, JsonValue>,
1334 }
1335
1336 #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
1338 pub enum KmsProviders {
1339 Local { key: String },
1341 Aws {
1343 access_key_id: String,
1344 secret_access_key: String,
1345 region: String,
1346 },
1347 Azure {
1349 tenant_id: String,
1350 client_id: String,
1351 client_secret: String,
1352 },
1353 Gcp { email: String, private_key: String },
1355 }
1356
1357 impl FieldEncryption {
1358 pub fn new(key_vault_namespace: impl Into<String>, kms: KmsProviders) -> Self {
1360 Self {
1361 key_vault_namespace: key_vault_namespace.into(),
1362 kms_providers: kms,
1363 schema_map: serde_json::Map::new(),
1364 }
1365 }
1366
1367 pub fn encrypt_field(
1369 mut self,
1370 namespace: impl Into<String>,
1371 field: impl Into<String>,
1372 algorithm: EncryptionAlgorithm,
1373 key_id: impl Into<String>,
1374 ) -> Self {
1375 let ns = namespace.into();
1376 let field = field.into();
1377
1378 let field_spec = serde_json::json!({
1379 "encrypt": {
1380 "bsonType": "string",
1381 "algorithm": algorithm.to_str(),
1382 "keyId": [{ "$binary": { "base64": key_id.into(), "subType": "04" } }]
1383 }
1384 });
1385
1386 let schema = self.schema_map.entry(ns).or_insert_with(|| {
1388 serde_json::json!({
1389 "bsonType": "object",
1390 "properties": {}
1391 })
1392 });
1393
1394 if let Some(props) = schema.get_mut("properties").and_then(|p| p.as_object_mut()) {
1395 props.insert(field, field_spec);
1396 }
1397
1398 self
1399 }
1400
1401 pub fn to_options(&self) -> JsonValue {
1403 let kms = match &self.kms_providers {
1404 KmsProviders::Local { key } => {
1405 serde_json::json!({ "local": { "key": key } })
1406 }
1407 KmsProviders::Aws {
1408 access_key_id,
1409 secret_access_key,
1410 region,
1411 } => {
1412 serde_json::json!({
1413 "aws": {
1414 "accessKeyId": access_key_id,
1415 "secretAccessKey": secret_access_key,
1416 "region": region
1417 }
1418 })
1419 }
1420 KmsProviders::Azure {
1421 tenant_id,
1422 client_id,
1423 client_secret,
1424 } => {
1425 serde_json::json!({
1426 "azure": {
1427 "tenantId": tenant_id,
1428 "clientId": client_id,
1429 "clientSecret": client_secret
1430 }
1431 })
1432 }
1433 KmsProviders::Gcp { email, private_key } => {
1434 serde_json::json!({
1435 "gcp": {
1436 "email": email,
1437 "privateKey": private_key
1438 }
1439 })
1440 }
1441 };
1442
1443 serde_json::json!({
1444 "keyVaultNamespace": self.key_vault_namespace,
1445 "kmsProviders": kms,
1446 "schemaMap": self.schema_map
1447 })
1448 }
1449 }
1450
1451 #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
1453 pub enum EncryptionAlgorithm {
1454 Deterministic,
1456 Random,
1458 }
1459
1460 impl EncryptionAlgorithm {
1461 pub fn to_str(&self) -> &'static str {
1463 match self {
1464 Self::Deterministic => "AEAD_AES_256_CBC_HMAC_SHA_512-Deterministic",
1465 Self::Random => "AEAD_AES_256_CBC_HMAC_SHA_512-Random",
1466 }
1467 }
1468 }
1469}
1470
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `haystack` contains every string in `needles`.
    fn assert_contains_all(haystack: &str, needles: &[&str]) {
        for &needle in needles {
            assert!(haystack.contains(needle));
        }
    }

    #[test]
    fn test_rls_policy_postgres() {
        let predicate = "tenant_id = current_setting('app.tenant_id')::INT";
        let sql = RlsPolicy::new("tenant_isolation", "orders")
            .using(predicate)
            .with_check(predicate)
            .build()
            .to_postgres_sql();

        assert_contains_all(
            &sql,
            &[
                "CREATE POLICY tenant_isolation ON orders",
                "USING (tenant_id =",
                "WITH CHECK (tenant_id =",
            ],
        );
    }

    #[test]
    fn test_rls_policy_for_select() {
        let sql = RlsPolicy::new("read_own", "documents")
            .for_select()
            .to_roles(["app_user"])
            .using("owner_id = current_user_id()")
            .build()
            .to_postgres_sql();

        assert_contains_all(&sql, &["FOR SELECT", "TO app_user"]);
    }

    #[test]
    fn test_tenant_policy() {
        let source = TenantSource::SessionVar(String::from("app.tenant_id"));
        let tenant = TenantPolicy::new("orders", "tenant_id", source);

        let policy = tenant.to_postgres_rls();
        assert!(policy.using.is_some());
        assert!(policy.with_check.is_some());

        let set_sql = tenant.set_tenant_sql("123", DatabaseType::PostgreSQL);
        assert!(set_sql.contains("SET LOCAL app.tenant_id"));
    }

    #[test]
    fn test_role_postgres() {
        let sql = Role::new("app_reader")
            .login()
            .password("secret")
            .connection_limit(10)
            .build()
            .to_postgres_sql();

        assert_contains_all(
            &sql,
            &[
                "CREATE ROLE app_reader",
                "LOGIN",
                "PASSWORD 'secret'",
                "CONNECTION LIMIT 10",
            ],
        );
    }

    #[test]
    fn test_role_inherit() {
        let sql = Role::new("senior_dev")
            .inherit(["developer", "analyst"])
            .build()
            .to_postgres_sql();

        assert!(sql.contains("IN ROLE developer, analyst"));
    }

    #[test]
    fn test_grant_table() {
        let sql = Grant::new("app_user")
            .select()
            .insert()
            .update()
            .on_table("users")
            .build()
            .unwrap()
            .to_postgres_sql();

        assert!(sql.contains("GRANT SELECT, INSERT, UPDATE ON TABLE users TO app_user"));
    }

    #[test]
    fn test_grant_columns() {
        let sql = Grant::new("restricted_user")
            .select()
            .on_columns("users", ["id", "name", "email"])
            .build()
            .unwrap()
            .to_postgres_sql();

        assert!(sql.contains("SELECT (id, name, email)"));
    }

    #[test]
    fn test_grant_with_option() {
        let sql = Grant::new("admin")
            .all()
            .on_schema("public")
            .with_grant_option()
            .build()
            .unwrap()
            .to_postgres_sql();

        assert!(sql.contains("WITH GRANT OPTION"));
    }

    #[test]
    fn test_data_mask_email() {
        let sql = DataMask::new("users", "email", MaskFunction::Email).to_mssql_alter();

        assert!(sql.contains("ADD MASKED WITH (FUNCTION = 'email()'"));
    }

    #[test]
    fn test_data_mask_partial() {
        let mask_fn = MaskFunction::Partial {
            prefix: 0,
            padding: String::from("XXX-XX-"),
            suffix: 4,
        };
        let sql = DataMask::new("users", "ssn", mask_fn).to_mssql_alter();

        assert!(sql.contains("partial(0, 'XXX-XX-', 4)"));
    }

    #[test]
    fn test_connection_profile() {
        let sqls = ConnectionProfile::new("readonly_user", "app_readonly")
            .search_path(["app", "public"])
            .read_only()
            .statement_timeout(5000)
            .build()
            .to_postgres_setup();

        let emitted = |needle: &str| sqls.iter().any(|s| s.contains(needle));
        assert!(emitted("SET ROLE app_readonly"));
        assert!(emitted("search_path TO app, public"));
        assert!(emitted("read_only = ON"));
        assert!(emitted("statement_timeout = 5000"));
    }

    mod mongodb_tests {
        use super::super::mongodb::*;

        #[test]
        fn test_mongo_role() {
            let cmd = MongoRole::new("app_reader", "mydb")
                .privilege_collection("orders", ["find", "aggregate"])
                .inherit("read", "mydb")
                .build()
                .to_create_command();

            assert_eq!(cmd["createRole"], "app_reader");
            assert!(cmd["privileges"].is_array());
            assert!(cmd["roles"].is_array());
        }

        #[test]
        fn test_field_encryption_local() {
            let kms = KmsProviders::Local {
                key: String::from("base64key"),
            };
            let opts = FieldEncryption::new("encryption.__keyVault", kms)
                .encrypt_field(
                    "mydb.users",
                    "ssn",
                    EncryptionAlgorithm::Deterministic,
                    "keyid",
                )
                .to_options();

            assert!(opts["kmsProviders"]["local"].is_object());
            assert!(opts["schemaMap"]["mydb.users"].is_object());
        }

        #[test]
        fn test_field_encryption_aws() {
            let kms = KmsProviders::Aws {
                access_key_id: String::from("AKID"),
                secret_access_key: String::from("secret"),
                region: String::from("us-east-1"),
            };
            let opts = FieldEncryption::new("encryption.__keyVault", kms).to_options();

            assert!(opts["kmsProviders"]["aws"]["accessKeyId"].is_string());
        }
    }
}