1use serde::{Deserialize, Serialize};
21
22use crate::error::{QueryError, QueryResult};
23use crate::sql::DatabaseType;
24
/// A row-level security policy definition.
///
/// Build with [`RlsPolicy::new`], which returns an [`RlsPolicyBuilder`], and render
/// with the dialect-specific `to_*_sql` methods.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct RlsPolicy {
    /// Policy name; also used to derive the SQL Server function and policy names.
    pub name: String,
    /// Table the policy is attached to.
    pub table: String,
    /// Command(s) the policy applies to (`FOR ...`).
    pub command: PolicyCommand,
    /// Roles the policy applies to. An empty list or exactly `["PUBLIC"]` emits no
    /// `TO` clause in the PostgreSQL output.
    pub roles: Vec<String>,
    /// Raw SQL boolean expression for the `USING` clause, inserted verbatim.
    pub using: Option<String>,
    /// Raw SQL boolean expression for the `WITH CHECK` clause, inserted verbatim.
    pub with_check: Option<String>,
    /// `true` renders `AS PERMISSIVE`, `false` renders `AS RESTRICTIVE`.
    pub permissive: bool,
}
47
/// The SQL command a row-level security policy applies to (the `FOR` clause).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum PolicyCommand {
    /// `FOR ALL`: every command.
    All,
    /// `FOR SELECT`.
    Select,
    /// `FOR INSERT`.
    Insert,
    /// `FOR UPDATE`.
    Update,
    /// `FOR DELETE`.
    Delete,
}
62
63impl PolicyCommand {
64 pub fn to_sql(&self) -> &'static str {
66 match self {
67 Self::All => "ALL",
68 Self::Select => "SELECT",
69 Self::Insert => "INSERT",
70 Self::Update => "UPDATE",
71 Self::Delete => "DELETE",
72 }
73 }
74}
75
76impl RlsPolicy {
77 pub fn new(name: impl Into<String>, table: impl Into<String>) -> RlsPolicyBuilder {
79 RlsPolicyBuilder::new(name, table)
80 }
81
82 pub fn to_postgres_sql(&self) -> String {
84 let mut sql = format!(
85 "CREATE POLICY {} ON {} AS {} FOR {}",
86 self.name,
87 self.table,
88 if self.permissive { "PERMISSIVE" } else { "RESTRICTIVE" },
89 self.command.to_sql()
90 );
91
92 if !self.roles.is_empty() && self.roles != vec!["PUBLIC"] {
93 sql.push_str(" TO ");
94 sql.push_str(&self.roles.join(", "));
95 }
96
97 if let Some(ref using) = self.using {
98 sql.push_str(" USING (");
99 sql.push_str(using);
100 sql.push(')');
101 }
102
103 if let Some(ref check) = self.with_check {
104 sql.push_str(" WITH CHECK (");
105 sql.push_str(check);
106 sql.push(')');
107 }
108
109 sql
110 }
111
112 pub fn to_mssql_sql(&self) -> Vec<String> {
114 let mut sqls = Vec::new();
115
116 let func_name = format!("fn_rls_{}", self.name);
118 if let Some(ref using) = self.using {
119 sqls.push(format!(
120 "CREATE FUNCTION dbo.{fn}(@tenant_id INT) \
121 RETURNS TABLE WITH SCHEMABINDING AS \
122 RETURN SELECT 1 AS result WHERE {expr}",
123 fn = func_name,
124 expr = using
125 ));
126 }
127
128 sqls.push(format!(
130 "CREATE SECURITY POLICY {name}_policy \
131 ADD FILTER PREDICATE dbo.{fn}(tenant_id) ON dbo.{table}, \
132 ADD BLOCK PREDICATE dbo.{fn}(tenant_id) ON dbo.{table} \
133 WITH (STATE = ON)",
134 name = self.name,
135 fn = func_name,
136 table = self.table
137 ));
138
139 sqls
140 }
141
142 pub fn to_drop_sql(&self, db_type: DatabaseType) -> String {
144 match db_type {
145 DatabaseType::PostgreSQL => format!("DROP POLICY IF EXISTS {} ON {}", self.name, self.table),
146 DatabaseType::MSSQL => format!("DROP SECURITY POLICY IF EXISTS {}_policy", self.name),
147 _ => String::new(),
148 }
149 }
150}
151
/// Fluent builder for [`RlsPolicy`]; created by [`RlsPolicy::new`].
#[derive(Debug, Clone)]
pub struct RlsPolicyBuilder {
    name: String,
    table: String,
    command: PolicyCommand,
    roles: Vec<String>,
    using: Option<String>,
    with_check: Option<String>,
    permissive: bool,
}
163
164impl RlsPolicyBuilder {
165 pub fn new(name: impl Into<String>, table: impl Into<String>) -> Self {
167 Self {
168 name: name.into(),
169 table: table.into(),
170 command: PolicyCommand::All,
171 roles: vec!["PUBLIC".to_string()],
172 using: None,
173 with_check: None,
174 permissive: true,
175 }
176 }
177
178 pub fn for_command(mut self, cmd: PolicyCommand) -> Self {
180 self.command = cmd;
181 self
182 }
183
184 pub fn for_select(self) -> Self {
186 self.for_command(PolicyCommand::Select)
187 }
188
189 pub fn for_insert(self) -> Self {
191 self.for_command(PolicyCommand::Insert)
192 }
193
194 pub fn for_update(self) -> Self {
196 self.for_command(PolicyCommand::Update)
197 }
198
199 pub fn for_delete(self) -> Self {
201 self.for_command(PolicyCommand::Delete)
202 }
203
204 pub fn to_roles<I, S>(mut self, roles: I) -> Self
206 where
207 I: IntoIterator<Item = S>,
208 S: Into<String>,
209 {
210 self.roles = roles.into_iter().map(Into::into).collect();
211 self
212 }
213
214 pub fn using(mut self, expr: impl Into<String>) -> Self {
216 self.using = Some(expr.into());
217 self
218 }
219
220 pub fn with_check(mut self, expr: impl Into<String>) -> Self {
222 self.with_check = Some(expr.into());
223 self
224 }
225
226 pub fn restrictive(mut self) -> Self {
228 self.permissive = false;
229 self
230 }
231
232 pub fn permissive(mut self) -> Self {
234 self.permissive = true;
235 self
236 }
237
238 pub fn build(self) -> RlsPolicy {
240 RlsPolicy {
241 name: self.name,
242 table: self.table,
243 command: self.command,
244 roles: self.roles,
245 using: self.using,
246 with_check: self.with_check,
247 permissive: self.permissive,
248 }
249 }
250}
251
/// Multi-tenant row isolation for one table.
///
/// `tenant_column` is compared with the tenant id obtained from `tenant_source`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TenantPolicy {
    /// Table to isolate.
    pub table: String,
    /// Column holding each row's tenant id. `to_postgres_rls` compares it against an
    /// `::INT` cast of the tenant source.
    pub tenant_column: String,
    /// Where the current tenant id is read from.
    pub tenant_source: TenantSource,
}
262
/// Source of the current tenant id used by [`TenantPolicy`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum TenantSource {
    /// A session setting name. PostgreSQL reads it with `current_setting(name)` and
    /// sets it with `SET LOCAL`.
    SessionVar(String),
    /// A session-context key. PostgreSQL also reads it with `current_setting(key)`.
    SessionContext(String),
    /// A SQL function name, called with no arguments to obtain the tenant id.
    Function(String),
}
273
274impl TenantPolicy {
275 pub fn new(
277 table: impl Into<String>,
278 tenant_column: impl Into<String>,
279 source: TenantSource,
280 ) -> Self {
281 Self {
282 table: table.into(),
283 tenant_column: tenant_column.into(),
284 tenant_source: source,
285 }
286 }
287
288 pub fn to_postgres_rls(&self) -> RlsPolicy {
290 let tenant_expr = match &self.tenant_source {
291 TenantSource::SessionVar(var) => format!("current_setting('{}')", var),
292 TenantSource::Function(func) => format!("{}()", func),
293 TenantSource::SessionContext(key) => format!("current_setting('{}')", key),
294 };
295
296 RlsPolicy::new(format!("{}_tenant_isolation", self.table), &self.table)
297 .using(format!("{} = {}::INT", self.tenant_column, tenant_expr))
298 .with_check(format!("{} = {}::INT", self.tenant_column, tenant_expr))
299 .build()
300 }
301
302 pub fn set_tenant_sql(&self, tenant_id: &str, db_type: DatabaseType) -> String {
304 match db_type {
305 DatabaseType::PostgreSQL => match &self.tenant_source {
306 TenantSource::SessionVar(var) => {
307 format!("SET LOCAL {} = '{}'", var, tenant_id)
308 }
309 _ => format!("SELECT set_config('app.tenant_id', '{}', true)", tenant_id),
310 },
311 DatabaseType::MSSQL => {
312 format!("EXEC sp_set_session_context @key = N'tenant_id', @value = {}", tenant_id)
313 }
314 _ => String::new(),
315 }
316 }
317}
318
319#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
325pub struct Role {
326 pub name: String,
328 pub login: bool,
330 pub password: Option<String>,
332 pub inherit_from: Vec<String>,
334 pub superuser: bool,
336 pub createdb: bool,
338 pub createrole: bool,
340 pub connection_limit: Option<i32>,
342 pub valid_until: Option<String>,
344}
345
346impl Role {
347 pub fn new(name: impl Into<String>) -> RoleBuilder {
349 RoleBuilder::new(name)
350 }
351
352 pub fn to_postgres_sql(&self) -> String {
354 let mut sql = format!("CREATE ROLE {}", self.name);
355 let mut options = Vec::new();
356
357 if self.login {
358 options.push("LOGIN".to_string());
359 } else {
360 options.push("NOLOGIN".to_string());
361 }
362
363 if let Some(ref pwd) = self.password {
364 options.push(format!("PASSWORD '{}'", pwd));
365 }
366
367 if self.superuser {
368 options.push("SUPERUSER".to_string());
369 }
370
371 if self.createdb {
372 options.push("CREATEDB".to_string());
373 }
374
375 if self.createrole {
376 options.push("CREATEROLE".to_string());
377 }
378
379 if let Some(limit) = self.connection_limit {
380 options.push(format!("CONNECTION LIMIT {}", limit));
381 }
382
383 if let Some(ref until) = self.valid_until {
384 options.push(format!("VALID UNTIL '{}'", until));
385 }
386
387 if !self.inherit_from.is_empty() {
388 options.push(format!("IN ROLE {}", self.inherit_from.join(", ")));
389 }
390
391 if !options.is_empty() {
392 sql.push_str(" WITH ");
393 sql.push_str(&options.join(" "));
394 }
395
396 sql
397 }
398
399 pub fn to_mysql_sql(&self) -> Vec<String> {
401 let mut sqls = Vec::new();
402
403 if self.login {
404 let mut sql = format!("CREATE USER '{}'@'%'", self.name);
405 if let Some(ref pwd) = self.password {
406 sql.push_str(&format!(" IDENTIFIED BY '{}'", pwd));
407 }
408 sqls.push(sql);
409 } else {
410 sqls.push(format!("CREATE ROLE '{}'", self.name));
411 }
412
413 for parent in &self.inherit_from {
414 sqls.push(format!("GRANT '{}' TO '{}'", parent, self.name));
415 }
416
417 sqls
418 }
419
420 pub fn to_mssql_sql(&self, database: &str) -> Vec<String> {
422 let mut sqls = Vec::new();
423
424 if self.login {
425 let mut sql = format!("CREATE LOGIN {} WITH PASSWORD = ", self.name);
426 if let Some(ref pwd) = self.password {
427 sql.push_str(&format!("'{}'", pwd));
428 } else {
429 sql.push_str("''");
430 }
431 sqls.push(sql);
432 sqls.push(format!("USE {}; CREATE USER {} FOR LOGIN {}", database, self.name, self.name));
433 } else {
434 sqls.push(format!("USE {}; CREATE ROLE {}", database, self.name));
435 }
436
437 for parent in &self.inherit_from {
438 sqls.push(format!("ALTER ROLE {} ADD MEMBER {}", parent, self.name));
439 }
440
441 sqls
442 }
443}
444
/// Fluent builder for [`Role`]; created by [`Role::new`].
///
/// `Debug` is implemented by hand so that the pending password is never printed.
#[derive(Clone)]
pub struct RoleBuilder {
    name: String,
    login: bool,
    password: Option<String>,
    inherit_from: Vec<String>,
    superuser: bool,
    createdb: bool,
    createrole: bool,
    connection_limit: Option<i32>,
    valid_until: Option<String>,
}

impl std::fmt::Debug for RoleBuilder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RoleBuilder")
            .field("name", &self.name)
            .field("login", &self.login)
            .field("password", &self.password.as_ref().map(|_| "<redacted>"))
            .field("inherit_from", &self.inherit_from)
            .field("superuser", &self.superuser)
            .field("createdb", &self.createdb)
            .field("createrole", &self.createrole)
            .field("connection_limit", &self.connection_limit)
            .field("valid_until", &self.valid_until)
            .finish()
    }
}
458
459impl RoleBuilder {
460 pub fn new(name: impl Into<String>) -> Self {
462 Self {
463 name: name.into(),
464 login: false,
465 password: None,
466 inherit_from: Vec::new(),
467 superuser: false,
468 createdb: false,
469 createrole: false,
470 connection_limit: None,
471 valid_until: None,
472 }
473 }
474
475 pub fn login(mut self) -> Self {
477 self.login = true;
478 self
479 }
480
481 pub fn password(mut self, pwd: impl Into<String>) -> Self {
483 self.password = Some(pwd.into());
484 self.login = true;
485 self
486 }
487
488 pub fn inherit<I, S>(mut self, roles: I) -> Self
490 where
491 I: IntoIterator<Item = S>,
492 S: Into<String>,
493 {
494 self.inherit_from = roles.into_iter().map(Into::into).collect();
495 self
496 }
497
498 pub fn superuser(mut self) -> Self {
500 self.superuser = true;
501 self
502 }
503
504 pub fn createdb(mut self) -> Self {
506 self.createdb = true;
507 self
508 }
509
510 pub fn createrole(mut self) -> Self {
512 self.createrole = true;
513 self
514 }
515
516 pub fn connection_limit(mut self, limit: i32) -> Self {
518 self.connection_limit = Some(limit);
519 self
520 }
521
522 pub fn valid_until(mut self, timestamp: impl Into<String>) -> Self {
524 self.valid_until = Some(timestamp.into());
525 self
526 }
527
528 pub fn build(self) -> Role {
530 Role {
531 name: self.name,
532 login: self.login,
533 password: self.password,
534 inherit_from: self.inherit_from,
535 superuser: self.superuser,
536 createdb: self.createdb,
537 createrole: self.createrole,
538 connection_limit: self.connection_limit,
539 valid_until: self.valid_until,
540 }
541 }
542}
543
/// A privilege grant on a database object, rendered per dialect by the `to_*_sql` methods.
///
/// Build with [`Grant::new`], which returns a [`GrantBuilder`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Grant {
    /// Privileges being granted. `GrantBuilder::build` rejects an empty list.
    pub privileges: Vec<Privilege>,
    /// Object the privileges apply to.
    pub object: GrantObject,
    /// Role or user receiving the privileges, inserted into the SQL as given.
    pub grantee: String,
    /// Whether `WITH GRANT OPTION` is requested.
    pub with_grant_option: bool,
}
560
/// A grantable privilege; [`Privilege::to_sql`] gives its SQL keyword.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum Privilege {
    /// `SELECT`.
    Select,
    /// `INSERT`.
    Insert,
    /// `UPDATE`.
    Update,
    /// `DELETE`.
    Delete,
    /// `TRUNCATE`.
    Truncate,
    /// `REFERENCES`.
    References,
    /// `TRIGGER`.
    Trigger,
    /// `ALL PRIVILEGES`.
    All,
    /// `EXECUTE`.
    Execute,
    /// `USAGE`.
    Usage,
    /// `CREATE`.
    Create,
    /// `CONNECT`.
    Connect,
}
589
590impl Privilege {
591 pub fn to_sql(&self) -> &'static str {
593 match self {
594 Self::Select => "SELECT",
595 Self::Insert => "INSERT",
596 Self::Update => "UPDATE",
597 Self::Delete => "DELETE",
598 Self::Truncate => "TRUNCATE",
599 Self::References => "REFERENCES",
600 Self::Trigger => "TRIGGER",
601 Self::All => "ALL PRIVILEGES",
602 Self::Execute => "EXECUTE",
603 Self::Usage => "USAGE",
604 Self::Create => "CREATE",
605 Self::Connect => "CONNECT",
606 }
607 }
608}
609
/// The target of a [`Grant`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum GrantObject {
    /// A table, optionally restricted to specific columns.
    Table { name: String, columns: Option<Vec<String>> },
    /// A schema.
    Schema(String),
    /// A database.
    Database(String),
    /// A sequence.
    Sequence(String),
    /// A function; `args` is the argument-type list text placed inside the parentheses.
    Function { name: String, args: String },
    /// Every table in the named schema.
    AllTablesInSchema(String),
    /// Every sequence in the named schema.
    AllSequencesInSchema(String),
}
628
629impl GrantObject {
630 pub fn table(name: impl Into<String>) -> Self {
632 Self::Table {
633 name: name.into(),
634 columns: None,
635 }
636 }
637
638 pub fn table_columns<I, S>(name: impl Into<String>, columns: I) -> Self
640 where
641 I: IntoIterator<Item = S>,
642 S: Into<String>,
643 {
644 Self::Table {
645 name: name.into(),
646 columns: Some(columns.into_iter().map(Into::into).collect()),
647 }
648 }
649
650 pub fn schema(name: impl Into<String>) -> Self {
652 Self::Schema(name.into())
653 }
654
655 pub fn to_sql(&self) -> String {
657 match self {
658 Self::Table { name, columns } => {
659 if let Some(cols) = columns {
660 format!("TABLE {} ({})", name, cols.join(", "))
661 } else {
662 format!("TABLE {}", name)
663 }
664 }
665 Self::Schema(name) => format!("SCHEMA {}", name),
666 Self::Database(name) => format!("DATABASE {}", name),
667 Self::Sequence(name) => format!("SEQUENCE {}", name),
668 Self::Function { name, args } => format!("FUNCTION {}({})", name, args),
669 Self::AllTablesInSchema(schema) => format!("ALL TABLES IN SCHEMA {}", schema),
670 Self::AllSequencesInSchema(schema) => format!("ALL SEQUENCES IN SCHEMA {}", schema),
671 }
672 }
673}
674
675impl Grant {
676 pub fn new(grantee: impl Into<String>) -> GrantBuilder {
678 GrantBuilder::new(grantee)
679 }
680
681 pub fn to_postgres_sql(&self) -> String {
683 let privs: Vec<&str> = self.privileges.iter().map(Privilege::to_sql).collect();
684 let priv_sql = match &self.object {
685 GrantObject::Table { columns: Some(cols), .. } => {
686 privs
688 .iter()
689 .map(|p| format!("{} ({})", p, cols.join(", ")))
690 .collect::<Vec<_>>()
691 .join(", ")
692 }
693 _ => privs.join(", "),
694 };
695
696 let obj_sql = match &self.object {
697 GrantObject::Table { name, columns: Some(_) } => format!("TABLE {}", name),
698 _ => self.object.to_sql(),
699 };
700
701 let mut sql = format!("GRANT {} ON {} TO {}", priv_sql, obj_sql, self.grantee);
702
703 if self.with_grant_option {
704 sql.push_str(" WITH GRANT OPTION");
705 }
706
707 sql
708 }
709
710 pub fn to_mysql_sql(&self) -> String {
712 let privs: Vec<&str> = self.privileges.iter().map(Privilege::to_sql).collect();
713 let priv_sql = match &self.object {
714 GrantObject::Table { columns: Some(cols), .. } => {
715 privs
716 .iter()
717 .map(|p| format!("{} ({})", p, cols.join(", ")))
718 .collect::<Vec<_>>()
719 .join(", ")
720 }
721 _ => privs.join(", "),
722 };
723
724 let obj = match &self.object {
725 GrantObject::Table { name, .. } => name.clone(),
726 GrantObject::Database(name) => format!("{}.*", name),
727 GrantObject::Schema(name) => format!("{}.*", name),
728 _ => "*.*".to_string(),
729 };
730
731 let mut sql = format!("GRANT {} ON {} TO '{}'@'%'", priv_sql, obj, self.grantee);
732
733 if self.with_grant_option {
734 sql.push_str(" WITH GRANT OPTION");
735 }
736
737 sql
738 }
739
740 pub fn to_mssql_sql(&self) -> String {
742 let privs: Vec<&str> = self.privileges.iter().map(Privilege::to_sql).collect();
743
744 let (obj_type, obj_name) = match &self.object {
745 GrantObject::Table { name, columns } => {
746 if let Some(cols) = columns {
747 return format!(
748 "GRANT {} ({}) ON {} TO {}",
749 privs.join(", "),
750 cols.join(", "),
751 name,
752 self.grantee
753 );
754 }
755 ("OBJECT", name.clone())
756 }
757 GrantObject::Schema(name) => ("SCHEMA", name.clone()),
758 GrantObject::Database(name) => ("DATABASE", name.clone()),
759 _ => ("OBJECT", "".to_string()),
760 };
761
762 format!(
763 "GRANT {} ON {}::{} TO {}",
764 privs.join(", "),
765 obj_type,
766 obj_name,
767 self.grantee
768 )
769 }
770}
771
/// Fluent builder for [`Grant`]; created by [`Grant::new`].
///
/// `build` fails unless an object and at least one privilege were supplied.
#[derive(Debug, Clone)]
pub struct GrantBuilder {
    grantee: String,
    privileges: Vec<Privilege>,
    object: Option<GrantObject>,
    with_grant_option: bool,
}
780
781impl GrantBuilder {
782 pub fn new(grantee: impl Into<String>) -> Self {
784 Self {
785 grantee: grantee.into(),
786 privileges: Vec::new(),
787 object: None,
788 with_grant_option: false,
789 }
790 }
791
792 pub fn privilege(mut self, priv_: Privilege) -> Self {
794 self.privileges.push(priv_);
795 self
796 }
797
798 pub fn select(self) -> Self {
800 self.privilege(Privilege::Select)
801 }
802
803 pub fn insert(self) -> Self {
805 self.privilege(Privilege::Insert)
806 }
807
808 pub fn update(self) -> Self {
810 self.privilege(Privilege::Update)
811 }
812
813 pub fn delete(self) -> Self {
815 self.privilege(Privilege::Delete)
816 }
817
818 pub fn all(self) -> Self {
820 self.privilege(Privilege::All)
821 }
822
823 pub fn on(mut self, object: GrantObject) -> Self {
825 self.object = Some(object);
826 self
827 }
828
829 pub fn on_table(self, table: impl Into<String>) -> Self {
831 self.on(GrantObject::table(table))
832 }
833
834 pub fn on_columns<I, S>(self, table: impl Into<String>, columns: I) -> Self
836 where
837 I: IntoIterator<Item = S>,
838 S: Into<String>,
839 {
840 self.on(GrantObject::table_columns(table, columns))
841 }
842
843 pub fn on_schema(self, schema: impl Into<String>) -> Self {
845 self.on(GrantObject::Schema(schema.into()))
846 }
847
848 pub fn with_grant_option(mut self) -> Self {
850 self.with_grant_option = true;
851 self
852 }
853
854 pub fn build(self) -> QueryResult<Grant> {
856 let object = self.object.ok_or_else(|| {
857 QueryError::invalid_input("object", "Grant requires an object (use on_table, on_schema, etc.)")
858 })?;
859
860 if self.privileges.is_empty() {
861 return Err(QueryError::invalid_input(
862 "privileges",
863 "Grant requires at least one privilege",
864 ));
865 }
866
867 Ok(Grant {
868 privileges: self.privileges,
869 object,
870 grantee: self.grantee,
871 with_grant_option: self.with_grant_option,
872 })
873 }
874}
875
/// A masking rule for one column, rendered as a PostgreSQL view or as a SQL Server
/// dynamic-data-masking `ALTER TABLE`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct DataMask {
    /// Table containing the column.
    pub table: String,
    /// Column to mask.
    pub column: String,
    /// Masking strategy.
    pub mask_function: MaskFunction,
}
890
/// Masking strategy for a [`DataMask`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum MaskFunction {
    /// PostgreSQL: `'****'` unless `current_user = 'admin'`. SQL Server: `default()`.
    Default,
    /// Keeps the first character and the domain of an e-mail address
    /// (SQL Server: `email()`).
    Email,
    /// Keeps `prefix` leading and `suffix` trailing characters, with `padding` in between.
    Partial { prefix: usize, padding: String, suffix: usize },
    /// PostgreSQL: `md5(random()::text)`. SQL Server: `random(1, 100)`.
    Random,
    /// PostgreSQL: the named function applied to the column. SQL Server: the text used
    /// verbatim as the mask function.
    Custom(String),
    /// PostgreSQL: `NULL`. SQL Server: `default()`.
    Null,
}
907
908impl DataMask {
909 pub fn new(table: impl Into<String>, column: impl Into<String>, mask: MaskFunction) -> Self {
911 Self {
912 table: table.into(),
913 column: column.into(),
914 mask_function: mask,
915 }
916 }
917
918 pub fn to_postgres_view(&self, view_name: &str) -> String {
920 let masked_expr = match &self.mask_function {
921 MaskFunction::Default => format!(
922 "CASE WHEN current_user = 'admin' THEN {} ELSE '****' END",
923 self.column
924 ),
925 MaskFunction::Email => format!(
926 "CASE WHEN current_user = 'admin' THEN {} ELSE \
927 CONCAT(LEFT({}, 1), '***@', SPLIT_PART({}, '@', 2)) END",
928 self.column, self.column, self.column
929 ),
930 MaskFunction::Partial { prefix, padding, suffix } => format!(
931 "CONCAT(LEFT({}, {}), '{}', RIGHT({}, {}))",
932 self.column, prefix, padding, self.column, suffix
933 ),
934 MaskFunction::Null => "NULL".to_string(),
935 MaskFunction::Custom(func) => format!("{}({})", func, self.column),
936 MaskFunction::Random => format!("md5(random()::text)"),
937 };
938
939 format!(
940 "CREATE OR REPLACE VIEW {} AS SELECT *, {} AS {}_masked FROM {}",
941 view_name, masked_expr, self.column, self.table
942 )
943 }
944
945 pub fn to_mssql_alter(&self) -> String {
947 let mask_func = match &self.mask_function {
948 MaskFunction::Default => "default()".to_string(),
949 MaskFunction::Email => "email()".to_string(),
950 MaskFunction::Partial { prefix, padding, suffix } => {
951 format!("partial({}, '{}', {})", prefix, padding, suffix)
952 }
953 MaskFunction::Random => "random(1, 100)".to_string(),
954 MaskFunction::Custom(func) => func.clone(),
955 MaskFunction::Null => "default()".to_string(),
956 };
957
958 format!(
959 "ALTER TABLE {} ALTER COLUMN {} ADD MASKED WITH (FUNCTION = '{}')",
960 self.table, self.column, mask_func
961 )
962 }
963}
964
/// Session settings applied after connecting.
///
/// Covers the role, search path, read-only mode, timeouts and custom session variables.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ConnectionProfile {
    /// Profile name (not emitted in any SQL).
    pub name: String,
    /// Role assumed with `SET ROLE` (PostgreSQL).
    pub role: String,
    /// Schemas for `SET search_path` (PostgreSQL).
    pub search_path: Vec<String>,
    /// `(name, value)` session variables, applied in insertion order.
    pub session_vars: Vec<(String, String)>,
    /// Whether transactions default to read-only.
    pub read_only: bool,
    /// Statement timeout in milliseconds (PostgreSQL).
    pub statement_timeout: Option<u32>,
    /// Lock-wait timeout in milliseconds (PostgreSQL).
    pub lock_timeout: Option<u32>,
}
987
988impl ConnectionProfile {
989 pub fn new(name: impl Into<String>, role: impl Into<String>) -> ConnectionProfileBuilder {
991 ConnectionProfileBuilder::new(name, role)
992 }
993
994 pub fn to_postgres_setup(&self) -> Vec<String> {
996 let mut sqls = Vec::new();
997
998 sqls.push(format!("SET ROLE {}", self.role));
999
1000 if !self.search_path.is_empty() {
1001 sqls.push(format!("SET search_path TO {}", self.search_path.join(", ")));
1002 }
1003
1004 if self.read_only {
1005 sqls.push("SET default_transaction_read_only = ON".to_string());
1006 }
1007
1008 if let Some(timeout) = self.statement_timeout {
1009 sqls.push(format!("SET statement_timeout = {}", timeout));
1010 }
1011
1012 if let Some(timeout) = self.lock_timeout {
1013 sqls.push(format!("SET lock_timeout = {}", timeout));
1014 }
1015
1016 for (key, value) in &self.session_vars {
1017 sqls.push(format!("SET {} = '{}'", key, value));
1018 }
1019
1020 sqls
1021 }
1022
1023 pub fn to_mysql_setup(&self) -> Vec<String> {
1025 let mut sqls = Vec::new();
1026
1027 if self.read_only {
1029 sqls.push("SET SESSION TRANSACTION READ ONLY".to_string());
1030 }
1031
1032 for (key, value) in &self.session_vars {
1033 sqls.push(format!("SET @{} = '{}'", key, value));
1034 }
1035
1036 sqls
1037 }
1038}
1039
/// Fluent builder for [`ConnectionProfile`]; created by [`ConnectionProfile::new`].
#[derive(Debug, Clone)]
pub struct ConnectionProfileBuilder {
    name: String,
    role: String,
    search_path: Vec<String>,
    session_vars: Vec<(String, String)>,
    read_only: bool,
    statement_timeout: Option<u32>,
    lock_timeout: Option<u32>,
}
1051
1052impl ConnectionProfileBuilder {
1053 pub fn new(name: impl Into<String>, role: impl Into<String>) -> Self {
1055 Self {
1056 name: name.into(),
1057 role: role.into(),
1058 search_path: Vec::new(),
1059 session_vars: Vec::new(),
1060 read_only: false,
1061 statement_timeout: None,
1062 lock_timeout: None,
1063 }
1064 }
1065
1066 pub fn search_path<I, S>(mut self, schemas: I) -> Self
1068 where
1069 I: IntoIterator<Item = S>,
1070 S: Into<String>,
1071 {
1072 self.search_path = schemas.into_iter().map(Into::into).collect();
1073 self
1074 }
1075
1076 pub fn session_var(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
1078 self.session_vars.push((key.into(), value.into()));
1079 self
1080 }
1081
1082 pub fn read_only(mut self) -> Self {
1084 self.read_only = true;
1085 self
1086 }
1087
1088 pub fn statement_timeout(mut self, ms: u32) -> Self {
1090 self.statement_timeout = Some(ms);
1091 self
1092 }
1093
1094 pub fn lock_timeout(mut self, ms: u32) -> Self {
1096 self.lock_timeout = Some(ms);
1097 self
1098 }
1099
1100 pub fn build(self) -> ConnectionProfile {
1102 ConnectionProfile {
1103 name: self.name,
1104 role: self.role,
1105 search_path: self.search_path,
1106 session_vars: self.session_vars,
1107 read_only: self.read_only,
1108 statement_timeout: self.statement_timeout,
1109 lock_timeout: self.lock_timeout,
1110 }
1111 }
1112}
1113
1114pub mod mongodb {
1120 use serde::{Deserialize, Serialize};
1121 use serde_json::Value as JsonValue;
1122
1123 #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
1125 pub struct MongoRole {
1126 pub role: String,
1128 pub db: String,
1130 pub privileges: Vec<MongoPrivilege>,
1132 pub roles: Vec<InheritedRole>,
1134 }
1135
1136 #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
1138 pub struct MongoPrivilege {
1139 pub resource: MongoResource,
1141 pub actions: Vec<String>,
1143 }
1144
1145 #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
1147 #[serde(untagged)]
1148 pub enum MongoResource {
1149 Collection { db: String, collection: String },
1151 Database { db: String },
1153 Cluster { cluster: bool },
1155 }
1156
1157 #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
1159 pub struct InheritedRole {
1160 pub role: String,
1162 pub db: String,
1164 }
1165
1166 impl MongoRole {
1167 pub fn new(role: impl Into<String>, db: impl Into<String>) -> MongoRoleBuilder {
1169 MongoRoleBuilder::new(role, db)
1170 }
1171
1172 pub fn to_create_command(&self) -> JsonValue {
1174 let privileges: Vec<JsonValue> = self
1175 .privileges
1176 .iter()
1177 .map(|p| {
1178 let resource = match &p.resource {
1179 MongoResource::Collection { db, collection } => {
1180 serde_json::json!({ "db": db, "collection": collection })
1181 }
1182 MongoResource::Database { db } => {
1183 serde_json::json!({ "db": db, "collection": "" })
1184 }
1185 MongoResource::Cluster { .. } => {
1186 serde_json::json!({ "cluster": true })
1187 }
1188 };
1189 serde_json::json!({
1190 "resource": resource,
1191 "actions": p.actions
1192 })
1193 })
1194 .collect();
1195
1196 let roles: Vec<JsonValue> = self
1197 .roles
1198 .iter()
1199 .map(|r| serde_json::json!({ "role": r.role, "db": r.db }))
1200 .collect();
1201
1202 serde_json::json!({
1203 "createRole": self.role,
1204 "privileges": privileges,
1205 "roles": roles
1206 })
1207 }
1208 }
1209
1210 #[derive(Debug, Clone, Default)]
1212 pub struct MongoRoleBuilder {
1213 role: String,
1214 db: String,
1215 privileges: Vec<MongoPrivilege>,
1216 roles: Vec<InheritedRole>,
1217 }
1218
1219 impl MongoRoleBuilder {
1220 pub fn new(role: impl Into<String>, db: impl Into<String>) -> Self {
1222 Self {
1223 role: role.into(),
1224 db: db.into(),
1225 privileges: Vec::new(),
1226 roles: Vec::new(),
1227 }
1228 }
1229
1230 pub fn privilege_collection<I, S>(
1232 mut self,
1233 collection: impl Into<String>,
1234 actions: I,
1235 ) -> Self
1236 where
1237 I: IntoIterator<Item = S>,
1238 S: Into<String>,
1239 {
1240 self.privileges.push(MongoPrivilege {
1241 resource: MongoResource::Collection {
1242 db: self.db.clone(),
1243 collection: collection.into(),
1244 },
1245 actions: actions.into_iter().map(Into::into).collect(),
1246 });
1247 self
1248 }
1249
1250 pub fn privilege_database<I, S>(mut self, actions: I) -> Self
1252 where
1253 I: IntoIterator<Item = S>,
1254 S: Into<String>,
1255 {
1256 self.privileges.push(MongoPrivilege {
1257 resource: MongoResource::Database { db: self.db.clone() },
1258 actions: actions.into_iter().map(Into::into).collect(),
1259 });
1260 self
1261 }
1262
1263 pub fn inherit(mut self, role: impl Into<String>, db: impl Into<String>) -> Self {
1265 self.roles.push(InheritedRole {
1266 role: role.into(),
1267 db: db.into(),
1268 });
1269 self
1270 }
1271
1272 pub fn build(self) -> MongoRole {
1274 MongoRole {
1275 role: self.role,
1276 db: self.db,
1277 privileges: self.privileges,
1278 roles: self.roles,
1279 }
1280 }
1281 }
1282
1283 #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
1285 pub struct FieldEncryption {
1286 pub key_vault_namespace: String,
1288 pub kms_providers: KmsProviders,
1290 pub schema_map: serde_json::Map<String, JsonValue>,
1292 }
1293
1294 #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
1296 pub enum KmsProviders {
1297 Local { key: String },
1299 Aws {
1301 access_key_id: String,
1302 secret_access_key: String,
1303 region: String,
1304 },
1305 Azure {
1307 tenant_id: String,
1308 client_id: String,
1309 client_secret: String,
1310 },
1311 Gcp {
1313 email: String,
1314 private_key: String,
1315 },
1316 }
1317
1318 impl FieldEncryption {
1319 pub fn new(key_vault_namespace: impl Into<String>, kms: KmsProviders) -> Self {
1321 Self {
1322 key_vault_namespace: key_vault_namespace.into(),
1323 kms_providers: kms,
1324 schema_map: serde_json::Map::new(),
1325 }
1326 }
1327
1328 pub fn encrypt_field(
1330 mut self,
1331 namespace: impl Into<String>,
1332 field: impl Into<String>,
1333 algorithm: EncryptionAlgorithm,
1334 key_id: impl Into<String>,
1335 ) -> Self {
1336 let ns = namespace.into();
1337 let field = field.into();
1338
1339 let field_spec = serde_json::json!({
1340 "encrypt": {
1341 "bsonType": "string",
1342 "algorithm": algorithm.to_str(),
1343 "keyId": [{ "$binary": { "base64": key_id.into(), "subType": "04" } }]
1344 }
1345 });
1346
1347 let schema = self.schema_map.entry(ns).or_insert_with(|| {
1349 serde_json::json!({
1350 "bsonType": "object",
1351 "properties": {}
1352 })
1353 });
1354
1355 if let Some(props) = schema.get_mut("properties").and_then(|p| p.as_object_mut()) {
1356 props.insert(field, field_spec);
1357 }
1358
1359 self
1360 }
1361
1362 pub fn to_options(&self) -> JsonValue {
1364 let kms = match &self.kms_providers {
1365 KmsProviders::Local { key } => {
1366 serde_json::json!({ "local": { "key": key } })
1367 }
1368 KmsProviders::Aws { access_key_id, secret_access_key, region } => {
1369 serde_json::json!({
1370 "aws": {
1371 "accessKeyId": access_key_id,
1372 "secretAccessKey": secret_access_key,
1373 "region": region
1374 }
1375 })
1376 }
1377 KmsProviders::Azure { tenant_id, client_id, client_secret } => {
1378 serde_json::json!({
1379 "azure": {
1380 "tenantId": tenant_id,
1381 "clientId": client_id,
1382 "clientSecret": client_secret
1383 }
1384 })
1385 }
1386 KmsProviders::Gcp { email, private_key } => {
1387 serde_json::json!({
1388 "gcp": {
1389 "email": email,
1390 "privateKey": private_key
1391 }
1392 })
1393 }
1394 };
1395
1396 serde_json::json!({
1397 "keyVaultNamespace": self.key_vault_namespace,
1398 "kmsProviders": kms,
1399 "schemaMap": self.schema_map
1400 })
1401 }
1402 }
1403
1404 #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
1406 pub enum EncryptionAlgorithm {
1407 Deterministic,
1409 Random,
1411 }
1412
1413 impl EncryptionAlgorithm {
1414 pub fn to_str(&self) -> &'static str {
1416 match self {
1417 Self::Deterministic => "AEAD_AES_256_CBC_HMAC_SHA_512-Deterministic",
1418 Self::Random => "AEAD_AES_256_CBC_HMAC_SHA_512-Random",
1419 }
1420 }
1421 }
1422}
1423
#[cfg(test)]
mod tests {
    use super::*;

    /// Fails with both strings in the message when `needle` is absent from `haystack`.
    fn assert_has(haystack: &str, needle: &str) {
        assert!(
            haystack.contains(needle),
            "expected `{}` to contain `{}`",
            haystack,
            needle
        );
    }

    #[test]
    fn test_rls_policy_postgres() {
        let predicate = "tenant_id = current_setting('app.tenant_id')::INT";
        let sql = RlsPolicy::new("tenant_isolation", "orders")
            .using(predicate)
            .with_check(predicate)
            .build()
            .to_postgres_sql();

        for needle in [
            "CREATE POLICY tenant_isolation ON orders",
            "USING (tenant_id =",
            "WITH CHECK (tenant_id =",
        ] {
            assert_has(&sql, needle);
        }
    }

    #[test]
    fn test_rls_policy_for_select() {
        let sql = RlsPolicy::new("read_own", "documents")
            .for_select()
            .to_roles(["app_user"])
            .using("owner_id = current_user_id()")
            .build()
            .to_postgres_sql();

        assert_has(&sql, "FOR SELECT");
        assert_has(&sql, "TO app_user");
    }

    #[test]
    fn test_tenant_policy() {
        let source = TenantSource::SessionVar("app.tenant_id".to_string());
        let tenant = TenantPolicy::new("orders", "tenant_id", source);

        let policy = tenant.to_postgres_rls();
        assert!(policy.using.is_some(), "tenant policy must filter reads");
        assert!(policy.with_check.is_some(), "tenant policy must check writes");

        let set_sql = tenant.set_tenant_sql("123", DatabaseType::PostgreSQL);
        assert_has(&set_sql, "SET LOCAL app.tenant_id");
    }

    #[test]
    fn test_role_postgres() {
        let sql = Role::new("app_reader")
            .login()
            .password("secret")
            .connection_limit(10)
            .build()
            .to_postgres_sql();

        for needle in [
            "CREATE ROLE app_reader",
            "LOGIN",
            "PASSWORD 'secret'",
            "CONNECTION LIMIT 10",
        ] {
            assert_has(&sql, needle);
        }
    }

    #[test]
    fn test_role_inherit() {
        let sql = Role::new("senior_dev")
            .inherit(["developer", "analyst"])
            .build()
            .to_postgres_sql();

        assert_has(&sql, "IN ROLE developer, analyst");
    }

    #[test]
    fn test_grant_table() {
        let grant = Grant::new("app_user")
            .select()
            .insert()
            .update()
            .on_table("users")
            .build()
            .expect("table grant should build");

        assert_has(
            &grant.to_postgres_sql(),
            "GRANT SELECT, INSERT, UPDATE ON TABLE users TO app_user",
        );
    }

    #[test]
    fn test_grant_columns() {
        let grant = Grant::new("restricted_user")
            .select()
            .on_columns("users", ["id", "name", "email"])
            .build()
            .expect("column grant should build");

        assert_has(&grant.to_postgres_sql(), "SELECT (id, name, email)");
    }

    #[test]
    fn test_grant_with_option() {
        let grant = Grant::new("admin")
            .all()
            .on_schema("public")
            .with_grant_option()
            .build()
            .expect("schema grant should build");

        assert_has(&grant.to_postgres_sql(), "WITH GRANT OPTION");
    }

    #[test]
    fn test_data_mask_email() {
        let sql = DataMask::new("users", "email", MaskFunction::Email).to_mssql_alter();
        assert_has(&sql, "ADD MASKED WITH (FUNCTION = 'email()'");
    }

    #[test]
    fn test_data_mask_partial() {
        let ssn_mask = MaskFunction::Partial {
            prefix: 0,
            padding: "XXX-XX-".to_string(),
            suffix: 4,
        };
        let sql = DataMask::new("users", "ssn", ssn_mask).to_mssql_alter();

        assert_has(&sql, "partial(0, 'XXX-XX-', 4)");
    }

    #[test]
    fn test_connection_profile() {
        let sqls = ConnectionProfile::new("readonly_user", "app_readonly")
            .search_path(["app", "public"])
            .read_only()
            .statement_timeout(5000)
            .build()
            .to_postgres_setup();

        let emitted = |needle: &str| sqls.iter().any(|s| s.contains(needle));
        assert!(emitted("SET ROLE app_readonly"));
        assert!(emitted("search_path TO app, public"));
        assert!(emitted("read_only = ON"));
        assert!(emitted("statement_timeout = 5000"));
    }

    mod mongodb_tests {
        use super::super::mongodb::*;

        #[test]
        fn test_mongo_role() {
            let cmd = MongoRole::new("app_reader", "mydb")
                .privilege_collection("orders", ["find", "aggregate"])
                .inherit("read", "mydb")
                .build()
                .to_create_command();

            assert_eq!(cmd["createRole"], "app_reader");
            assert!(cmd["privileges"].is_array());
            assert!(cmd["roles"].is_array());
        }

        #[test]
        fn test_field_encryption_local() {
            let kms = KmsProviders::Local {
                key: "base64key".to_string(),
            };
            let opts = FieldEncryption::new("encryption.__keyVault", kms)
                .encrypt_field(
                    "mydb.users",
                    "ssn",
                    EncryptionAlgorithm::Deterministic,
                    "keyid",
                )
                .to_options();

            assert!(opts["kmsProviders"]["local"].is_object());
            assert!(opts["schemaMap"]["mydb.users"].is_object());
        }

        #[test]
        fn test_field_encryption_aws() {
            let kms = KmsProviders::Aws {
                access_key_id: "AKID".to_string(),
                secret_access_key: "secret".to_string(),
                region: "us-east-1".to_string(),
            };
            let opts = FieldEncryption::new("encryption.__keyVault", kms).to_options();

            assert!(opts["kmsProviders"]["aws"]["accessKeyId"].is_string());
        }
    }
}
1620
1621
1622
1623