use std::collections::HashSet;
/// How tenant data is separated from other tenants' data.
#[derive(Debug, Clone)]
pub enum IsolationStrategy {
    /// Tenants share tables; rows are told apart by a tenant column.
    RowLevel(RowLevelConfig),
    /// Each tenant is mapped to its own schema.
    Schema(SchemaConfig),
    /// Each tenant is mapped to its own database.
    Database(DatabaseConfig),
    /// Combination of two strategies. The `*_config` accessors search the
    /// first strategy before the second and return the first match.
    Hybrid(Box<IsolationStrategy>, Box<IsolationStrategy>),
}

impl IsolationStrategy {
    /// Row-level isolation on `column` with default [`RowLevelConfig`] settings.
    pub fn row_level(column: impl Into<String>) -> Self {
        Self::RowLevel(RowLevelConfig::new(column))
    }

    /// Schema-per-tenant isolation with a default [`SchemaConfig`].
    pub fn schema_based() -> Self {
        Self::Schema(SchemaConfig::default())
    }

    /// Database-per-tenant isolation with a default [`DatabaseConfig`].
    pub fn database_based() -> Self {
        Self::Database(DatabaseConfig::default())
    }

    /// `true` only when this strategy *is* `RowLevel`.
    ///
    /// A `Hybrid` that contains a row-level strategy returns `false`; use
    /// `row_level_config().is_some()` to also look inside `Hybrid`.
    pub fn is_row_level(&self) -> bool {
        matches!(self, Self::RowLevel(_))
    }

    /// `true` only when this strategy *is* `Schema` (not searched inside `Hybrid`).
    pub fn is_schema_based(&self) -> bool {
        matches!(self, Self::Schema(_))
    }

    /// `true` only when this strategy *is* `Database` (not searched inside `Hybrid`).
    pub fn is_database_based(&self) -> bool {
        matches!(self, Self::Database(_))
    }

    /// `true` when this strategy is a `Hybrid` combination.
    pub fn is_hybrid(&self) -> bool {
        matches!(self, Self::Hybrid(..))
    }

    /// Row-level settings of this strategy, searching `Hybrid` parts left to right.
    pub fn row_level_config(&self) -> Option<&RowLevelConfig> {
        match self {
            Self::RowLevel(config) => Some(config),
            Self::Hybrid(a, b) => a.row_level_config().or_else(|| b.row_level_config()),
            // Listed explicitly (no catch-all) so adding a variant forces a review here.
            Self::Schema(_) | Self::Database(_) => None,
        }
    }

    /// Schema settings of this strategy, searching `Hybrid` parts left to right.
    pub fn schema_config(&self) -> Option<&SchemaConfig> {
        match self {
            Self::Schema(config) => Some(config),
            Self::Hybrid(a, b) => a.schema_config().or_else(|| b.schema_config()),
            Self::RowLevel(_) | Self::Database(_) => None,
        }
    }

    /// Database settings of this strategy, searching `Hybrid` parts left to right.
    pub fn database_config(&self) -> Option<&DatabaseConfig> {
        match self {
            Self::Database(config) => Some(config),
            Self::Hybrid(a, b) => a.database_config().or_else(|| b.database_config()),
            Self::RowLevel(_) | Self::Schema(_) => None,
        }
    }
}
/// Configuration for row-level (shared-table) tenant isolation.
///
/// Start from [`RowLevelConfig::new`] and chain the builder methods to adjust it.
#[derive(Debug, Clone)]
pub struct RowLevelConfig {
    /// Name of the column that carries the tenant id.
    pub column: String,
    /// SQL type of `column`; `ColumnType::String` unless overridden.
    pub column_type: ColumnType,
    /// Tables for which [`RowLevelConfig::should_filter`] returns `false`.
    /// NOTE(review): how callers treat these differently from `shared_tables` — confirm.
    pub excluded_tables: HashSet<String>,
    /// Tables for which [`RowLevelConfig::should_filter`] returns `false`.
    pub shared_tables: HashSet<String>,
    /// Defaults to `true`. Not read in this file; presumably auto-fills the
    /// tenant column on insert — confirm with callers.
    pub auto_insert: bool,
    /// Defaults to `true`. Not read in this file; presumably validates the
    /// tenant column on writes — confirm with callers.
    pub validate_writes: bool,
    /// Defaults to `false`. Not read in this file; presumably switches to
    /// database-side row-level security — confirm with callers.
    pub use_database_rls: bool,
}

impl RowLevelConfig {
    /// Creates a config for the tenant column `column`.
    ///
    /// The column type is `String`, and there are no excluded or shared tables.
    /// `auto_insert` and `validate_writes` are on; `use_database_rls` is off.
    pub fn new(column: impl Into<String>) -> Self {
        let column: String = column.into();
        Self {
            auto_insert: true,
            validate_writes: true,
            use_database_rls: false,
            excluded_tables: HashSet::new(),
            shared_tables: HashSet::new(),
            column_type: ColumnType::String,
            column,
        }
    }

    /// Overrides the SQL type of the tenant column.
    pub fn with_column_type(self, column_type: ColumnType) -> Self {
        Self { column_type, ..self }
    }

    /// Adds `table` to `excluded_tables`.
    pub fn exclude_table(mut self, table: impl Into<String>) -> Self {
        let name: String = table.into();
        self.excluded_tables.insert(name);
        self
    }

    /// Adds `table` to `shared_tables`.
    pub fn shared_table(mut self, table: impl Into<String>) -> Self {
        let name: String = table.into();
        self.shared_tables.insert(name);
        self
    }

    /// Sets `auto_insert` to `false`.
    pub fn without_auto_insert(self) -> Self {
        Self { auto_insert: false, ..self }
    }

    /// Sets `validate_writes` to `false`.
    pub fn without_write_validation(self) -> Self {
        Self { validate_writes: false, ..self }
    }

    /// Sets `use_database_rls` to `true`.
    pub fn with_database_rls(self) -> Self {
        Self { use_database_rls: true, ..self }
    }

    /// Returns `true` unless `table` is listed in `excluded_tables` or `shared_tables`.
    pub fn should_filter(&self, table: &str) -> bool {
        let skipped = self.excluded_tables.contains(table) || self.shared_tables.contains(table);
        !skipped
    }
}
/// SQL type of the tenant column; decides how tenant id values are rendered as literals.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ColumnType {
    /// Text column (the default); values render as single-quoted string literals.
    #[default]
    String,
    /// UUID column; values render as a quoted literal cast with `::uuid`.
    Uuid,
    /// Integer column (presumably SQL `INTEGER`); numeric values render bare.
    Integer,
    /// Big integer column (presumably SQL `BIGINT`); numeric values render bare.
    BigInt,
}

impl ColumnType {
    /// Returns the positional bind placeholder `$index`; identical for every column type.
    pub fn placeholder(&self, index: usize) -> String {
        format!("${}", index)
    }

    /// Renders `value` as an inline SQL literal for this column type.
    ///
    /// * `String`: single-quoted, with embedded `'` doubled.
    /// * `Uuid`: single-quoted (embedded `'` doubled) and cast with `::uuid`.
    /// * `Integer` / `BigInt`: emitted bare when `value` is an optionally signed
    ///   run of ASCII digits. Anything else is emitted as an escaped string
    ///   literal, so it stays a single (invalid) literal and cannot splice
    ///   additional SQL into the statement.
    ///
    /// Output for well-formed values is unchanged from plain interpolation.
    /// Only malformed or hostile values are escaped.
    pub fn format_value(&self, value: &str) -> String {
        match self {
            Self::String => Self::quote_literal(value),
            Self::Uuid => format!("{}::uuid", Self::quote_literal(value)),
            Self::Integer | Self::BigInt => {
                if Self::is_integer_literal(value) {
                    value.to_string()
                } else {
                    Self::quote_literal(value)
                }
            }
        }
    }

    /// Wraps `value` in single quotes, doubling any embedded single quote.
    fn quote_literal(value: &str) -> String {
        format!("'{}'", value.replace('\'', "''"))
    }

    /// `true` for an optional leading `+`/`-` followed by one or more ASCII digits.
    fn is_integer_literal(value: &str) -> bool {
        let digits = value
            .strip_prefix('-')
            .or_else(|| value.strip_prefix('+'))
            .unwrap_or(value);
        !digits.is_empty() && digits.bytes().all(|b| b.is_ascii_digit())
    }
}
/// Settings for schema-per-tenant isolation.
///
/// A tenant's schema is named `schema_prefix + tenant_id + schema_suffix`
/// (see [`SchemaConfig::schema_name`]).
#[derive(Debug, Clone, Default)]
pub struct SchemaConfig {
    /// Optional text prepended to the tenant id in the schema name.
    pub schema_prefix: Option<String>,
    /// Optional text appended to the tenant id in the schema name.
    pub schema_suffix: Option<String>,
    /// Schema paired with the tenant schema in `search_path`; `public` is used when unset.
    pub shared_schema: Option<String>,
    /// Not read in this file; presumably creates tenant schemas on demand — confirm.
    pub auto_create: bool,
    /// Not read in this file. NOTE(review): confirm how callers use it.
    pub default_schema: Option<String>,
    /// Ordering of tenant and shared schemas in the generated `search_path`.
    pub search_path_format: SearchPathFormat,
}

impl SchemaConfig {
    /// Sets `schema_prefix`.
    pub fn with_prefix(mut self, prefix: impl Into<String>) -> Self {
        self.schema_prefix = Some(prefix.into());
        self
    }

    /// Sets `schema_suffix`.
    pub fn with_suffix(mut self, suffix: impl Into<String>) -> Self {
        self.schema_suffix = Some(suffix.into());
        self
    }

    /// Sets `shared_schema`.
    pub fn with_shared_schema(mut self, schema: impl Into<String>) -> Self {
        self.shared_schema = Some(schema.into());
        self
    }

    /// Sets `auto_create` to `true`.
    pub fn with_auto_create(mut self) -> Self {
        self.auto_create = true;
        self
    }

    /// Sets `default_schema`.
    pub fn with_default_schema(mut self, schema: impl Into<String>) -> Self {
        self.default_schema = Some(schema.into());
        self
    }

    /// Sets `search_path_format`.
    pub fn with_search_path(mut self, format: SearchPathFormat) -> Self {
        self.search_path_format = format;
        self
    }

    /// Returns `schema_prefix + tenant_id + schema_suffix` (missing parts are omitted).
    ///
    /// The name is returned raw (unquoted).
    pub fn schema_name(&self, tenant_id: &str) -> String {
        let mut name = String::new();
        if let Some(prefix) = &self.schema_prefix {
            name.push_str(prefix);
        }
        name.push_str(tenant_id);
        if let Some(suffix) = &self.schema_suffix {
            name.push_str(suffix);
        }
        name
    }

    /// Builds a `SET search_path TO ...` statement for `tenant_id`.
    ///
    /// The shared schema (or `public` when none is configured) is placed
    /// according to `search_path_format`.
    ///
    /// Names consisting only of alphanumerics, `_` and `$` (not starting with
    /// `$`) are emitted verbatim, exactly as before. Any other name is emitted
    /// as a double-quoted identifier with embedded `"` doubled, so a tenant id
    /// such as `x; DROP TABLE users` cannot inject extra SQL.
    pub fn search_path(&self, tenant_id: &str) -> String {
        let tenant_schema = Self::quote_identifier(&self.schema_name(tenant_id));
        let shared_schema = match &self.shared_schema {
            Some(shared) => Self::quote_identifier(shared),
            None => "public".to_string(),
        };
        match self.search_path_format {
            SearchPathFormat::TenantOnly => format!("SET search_path TO {}", tenant_schema),
            SearchPathFormat::TenantFirst => {
                format!("SET search_path TO {}, {}", tenant_schema, shared_schema)
            }
            SearchPathFormat::SharedFirst => {
                format!("SET search_path TO {}, {}", shared_schema, tenant_schema)
            }
        }
    }

    /// Returns `name` unchanged when it is a plain identifier.
    ///
    /// A plain identifier is non-empty, contains only alphanumerics, `_` or
    /// `$`, and does not start with `$`. Any other name is wrapped in double
    /// quotes with embedded `"` doubled.
    fn quote_identifier(name: &str) -> String {
        let is_plain = !name.is_empty()
            && !name.starts_with('$')
            && name.chars().all(|c| c.is_alphanumeric() || c == '_' || c == '$');
        if is_plain {
            name.to_string()
        } else {
            format!("\"{}\"", name.replace('"', "\"\""))
        }
    }
}

/// Order of schemas in the `search_path` produced by [`SchemaConfig::search_path`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum SearchPathFormat {
    /// Only the tenant schema.
    TenantOnly,
    /// Tenant schema, then the shared schema (or `public`). The default.
    #[default]
    TenantFirst,
    /// Shared schema (or `public`), then the tenant schema.
    SharedFirst,
}
/// Settings for database-per-tenant isolation.
///
/// A tenant's database is named `database_prefix + tenant_id + database_suffix`
/// (see [`DatabaseConfig::database_name`]).
#[derive(Debug, Clone, Default)]
pub struct DatabaseConfig {
    /// Optional text prepended to the tenant id in the database name.
    pub database_prefix: Option<String>,
    /// Optional text appended to the tenant id in the database name.
    pub database_suffix: Option<String>,
    /// Not read in this file; presumably creates tenant databases on demand — confirm.
    pub auto_create: bool,
    /// Not read in this file; presumably the template for new tenant databases — confirm.
    pub template_database: Option<String>,
    /// Per-tenant pool size. `0` under `Default`.
    /// NOTE(review): confirm how callers treat 0.
    pub pool_size_per_tenant: usize,
    /// Per-tenant connection cap. `0` under `Default`.
    /// NOTE(review): confirm how callers treat 0.
    pub max_tenant_connections: usize,
}

impl DatabaseConfig {
    /// Sets `database_prefix`.
    pub fn with_prefix(self, prefix: impl Into<String>) -> Self {
        Self { database_prefix: Some(prefix.into()), ..self }
    }

    /// Sets `database_suffix`.
    pub fn with_suffix(self, suffix: impl Into<String>) -> Self {
        Self { database_suffix: Some(suffix.into()), ..self }
    }

    /// Sets `auto_create` to `true`.
    pub fn with_auto_create(self) -> Self {
        Self { auto_create: true, ..self }
    }

    /// Sets `template_database`.
    pub fn with_template(self, template: impl Into<String>) -> Self {
        Self { template_database: Some(template.into()), ..self }
    }

    /// Sets `pool_size_per_tenant`.
    pub fn with_pool_size(self, size: usize) -> Self {
        Self { pool_size_per_tenant: size, ..self }
    }

    /// Sets `max_tenant_connections`.
    pub fn with_max_connections(self, max: usize) -> Self {
        Self { max_tenant_connections: max, ..self }
    }

    /// Returns `database_prefix + tenant_id + database_suffix`; unset parts are treated as empty.
    pub fn database_name(&self, tenant_id: &str) -> String {
        let head = self.database_prefix.as_deref().unwrap_or("");
        let tail = self.database_suffix.as_deref().unwrap_or("");
        format!("{}{}{}", head, tenant_id, tail)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder settings are stored, and excluded/shared tables are not filtered.
    #[test]
    fn test_row_level_config() {
        let cfg = RowLevelConfig::new("tenant_id")
            .with_column_type(ColumnType::Uuid)
            .exclude_table("audit_logs")
            .shared_table("plans");

        assert_eq!(cfg.column, "tenant_id");
        assert_eq!(cfg.column_type, ColumnType::Uuid);

        let expectations = [("users", true), ("audit_logs", false), ("plans", false)];
        for &(table, filtered) in expectations.iter() {
            assert_eq!(cfg.should_filter(table), filtered, "should_filter({:?})", table);
        }
    }

    /// The schema name joins prefix and tenant id.
    ///
    /// The search path mentions both the tenant schema and the shared schema.
    #[test]
    fn test_schema_config() {
        let cfg = SchemaConfig::default()
            .with_prefix("tenant_")
            .with_shared_schema("shared");
        assert_eq!(cfg.schema_name("acme"), "tenant_acme");

        let path = cfg.search_path("acme");
        for &needle in ["tenant_acme", "shared"].iter() {
            assert!(path.contains(needle), "{:?} missing from {:?}", needle, path);
        }
    }

    /// The database name joins prefix, tenant id and suffix.
    #[test]
    fn test_database_config() {
        let cfg = DatabaseConfig::default()
            .with_prefix("prax_")
            .with_suffix("_db");
        assert_eq!(cfg.database_name("acme"), "prax_acme_db");
    }

    /// Each column type renders a well-formed value as the expected SQL literal.
    #[test]
    fn test_column_type_format() {
        let cases = [
            (ColumnType::String, "test", "'test'"),
            (
                ColumnType::Uuid,
                "123e4567-e89b-12d3-a456-426614174000",
                "'123e4567-e89b-12d3-a456-426614174000'::uuid",
            ),
            (ColumnType::Integer, "42", "42"),
        ];
        for &(ty, raw, rendered) in cases.iter() {
            assert_eq!(ty.format_value(raw), rendered);
        }
    }
}