use serde::{Deserialize, Serialize};
use crate::tenant::TenantId;
use super::{TenantResolution, TenantResolver, TenantValidationError};
/// Configuration for the schema-per-tenant strategy: how tenant schemas are named,
/// which schemas appear on the search path, and how schema names are validated.
///
/// Fields missing from a serialized document fall back to the defaults named in the
/// `serde` attributes below (`template_schema`, an `Option`, becomes `None`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaPerTenantConfig {
/// Prefix prepended to the normalized tenant id to build the tenant's schema name
/// (default `"tenant_"`); also the prefix used when listing tenant schemas.
#[serde(default = "default_schema_prefix")]
pub schema_prefix: String,
/// Schema shared by all tenants (default `"shared"`); it follows the tenant schema in
/// the tenant search path and is the schema resolved for the system tenant.
#[serde(default = "default_shared_schema")]
pub shared_schema: String,
/// Defaults to `true`. Not read anywhere in this file.
/// NOTE(review): presumably tells callers to create a tenant's schema on first use — confirm.
#[serde(default = "default_true")]
pub auto_create_schema: bool,
/// When `true` (the default), `public` is appended after the shared schema in the
/// system search path.
#[serde(default = "default_true")]
pub system_uses_public: bool,
/// Maximum schema-name length in bytes (`str::len`) accepted by validation; defaults to 63.
#[serde(default = "default_max_schema_length")]
pub max_schema_length: usize,
/// Regular expression every schema name must match; compiled when the strategy is
/// constructed. Defaults to `^[a-z][a-z0-9_]*$`.
#[serde(default = "default_schema_pattern")]
pub schema_pattern: String,
/// Defaults to `false`. Not read anywhere in this file.
/// NOTE(review): presumably requests dropping the schema when a tenant is deleted — confirm.
#[serde(default)]
pub drop_on_delete: bool,
/// Optional schema named in a `TEMPLATE` clause of generated `CREATE SCHEMA` statements.
/// NOTE(review): PostgreSQL's `CREATE SCHEMA` has no `TEMPLATE` clause, so SQL generated
/// with this set will be rejected by PostgreSQL — confirm the intended behavior.
pub template_schema: Option<String>,
}
/// Serde default for `schema_prefix`: tenant schemas are named `tenant_<normalized id>`.
fn default_schema_prefix() -> String {
    String::from("tenant_")
}
/// Serde default for `shared_schema`: the schema every tenant shares.
fn default_shared_schema() -> String {
    String::from("shared")
}
/// Serde default for boolean flags that are enabled unless configured otherwise.
fn default_true() -> bool {
    const ENABLED: bool = true;
    ENABLED
}
/// Serde default for `max_schema_length`.
///
/// 63 equals PostgreSQL's default maximum identifier length (`NAMEDATALEN - 1` bytes),
/// presumably the reason for this value.
fn default_max_schema_length() -> usize {
    const POSTGRES_MAX_IDENTIFIER_BYTES: usize = 63;
    POSTGRES_MAX_IDENTIFIER_BYTES
}
/// Serde default for `schema_pattern`: one lowercase ASCII letter followed by any
/// number of lowercase ASCII letters, digits, or underscores.
fn default_schema_pattern() -> String {
    String::from(r"^[a-z][a-z0-9_]*$")
}
impl Default for SchemaPerTenantConfig {
fn default() -> Self {
Self {
schema_prefix: default_schema_prefix(),
shared_schema: default_shared_schema(),
auto_create_schema: true,
system_uses_public: true,
max_schema_length: default_max_schema_length(),
schema_pattern: default_schema_pattern(),
drop_on_delete: false,
template_schema: None,
}
}
}
impl SchemaPerTenantConfig {
    /// Creates a configuration with every field at its default (same as `Default::default`).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the prefix prepended to each normalized tenant id to form its schema name.
    #[must_use]
    pub fn with_prefix(mut self, prefix: impl Into<String>) -> Self {
        self.schema_prefix = prefix.into();
        self
    }

    /// Sets the schema shared by all tenants: it follows the tenant schema in the tenant
    /// search path and is the schema resolved for the system tenant.
    #[must_use]
    pub fn with_shared_schema(mut self, schema: impl Into<String>) -> Self {
        self.shared_schema = schema.into();
        self
    }

    /// Sets the schema named in the `TEMPLATE` clause of generated `CREATE SCHEMA` statements.
    ///
    /// NOTE(review): PostgreSQL's `CREATE SCHEMA` has no `TEMPLATE` clause — confirm the
    /// target database before relying on this.
    #[must_use]
    pub fn with_template(mut self, template: impl Into<String>) -> Self {
        self.template_schema = Some(template.into());
        self
    }

    /// Sets `drop_on_delete` to `true`.
    ///
    /// NOTE(review): the flag is not read in this file; presumably consulted when a tenant
    /// is deleted — confirm.
    #[must_use]
    pub fn with_drop_on_delete(mut self) -> Self {
        self.drop_on_delete = true;
        self
    }
}
/// Maps tenants to their own database schema and generates the SQL (PostgreSQL dialect:
/// `search_path`, `information_schema`) that creates, drops, lists and switches between
/// those schemas.
#[derive(Debug, Clone)]
pub struct SchemaPerTenantStrategy {
/// Naming, search-path and validation settings.
config: SchemaPerTenantConfig,
/// `config.schema_pattern`, compiled once in `SchemaPerTenantStrategy::new`.
schema_pattern: regex::Regex,
}
impl SchemaPerTenantStrategy {
    /// Builds a strategy from `config`, compiling `config.schema_pattern` once up front.
    ///
    /// # Errors
    ///
    /// Returns the `regex::Error` reported when `config.schema_pattern` is not a valid
    /// regular expression.
    pub fn new(config: SchemaPerTenantConfig) -> Result<Self, regex::Error> {
        let schema_pattern = regex::Regex::new(&config.schema_pattern)?;
        Ok(Self {
            config,
            schema_pattern,
        })
    }

    /// Returns the configuration this strategy was built from.
    pub fn config(&self) -> &SchemaPerTenantConfig {
        &self.config
    }

    /// Returns the name of the schema shared by every tenant.
    pub fn shared_schema(&self) -> &str {
        &self.config.shared_schema
    }

    /// Maps `tenant_id` to its schema name: `schema_prefix` followed by the normalized id.
    ///
    /// The result is not validated here; validation happens in the `TenantResolver::validate`
    /// implementation.
    pub fn tenant_to_schema(&self, tenant_id: &TenantId) -> String {
        let normalized = self.normalize_tenant_id(tenant_id.as_str());
        format!("{}{}", self.config.schema_prefix, normalized)
    }

    /// Lowercases `id`, maps `/` and `-` to `_`, and drops every character that is not an
    /// ASCII letter, ASCII digit, or `_`.
    ///
    /// NOTE(review): the mapping is lossy — `acme-corp`, `acme_corp` and `acme/corp` all
    /// become `acme_corp` — so distinct tenant ids can land in the same schema. Confirm
    /// tenant ids are constrained upstream.
    fn normalize_tenant_id(&self, id: &str) -> String {
        id.to_lowercase()
            .replace(['/', '-'], "_")
            .chars()
            .filter(|c| c.is_ascii_alphanumeric() || *c == '_')
            .collect()
    }

    /// Returns `SET search_path` SQL that resolves unqualified names in the tenant's schema
    /// first, then the shared schema, then `public`.
    pub fn set_search_path_sql(&self, tenant_id: &TenantId) -> String {
        let schema = self.tenant_to_schema(tenant_id);
        format!(
            "SET search_path TO {}, {}, public",
            self.escape_identifier(&schema),
            self.escape_identifier(&self.config.shared_schema)
        )
    }

    /// Returns `SET search_path` SQL for system work: the shared schema, followed by
    /// `public` when `system_uses_public` is set.
    pub fn set_system_search_path_sql(&self) -> String {
        if self.config.system_uses_public {
            format!(
                "SET search_path TO {}, public",
                self.escape_identifier(&self.config.shared_schema)
            )
        } else {
            format!(
                "SET search_path TO {}",
                self.escape_identifier(&self.config.shared_schema)
            )
        }
    }

    /// Returns `RESET search_path`, which restores the session's default search path.
    pub fn reset_search_path_sql(&self) -> String {
        "RESET search_path".to_string()
    }

    /// Returns `CREATE SCHEMA IF NOT EXISTS` SQL for the tenant's schema.
    ///
    /// NOTE(review): when `template_schema` is set a `TEMPLATE` clause is appended, but
    /// PostgreSQL's `CREATE SCHEMA` has no such clause (only `CREATE DATABASE` does), so
    /// PostgreSQL rejects that form; copying a template schema needs separate DDL.
    pub fn create_schema_sql(&self, tenant_id: &TenantId) -> String {
        let schema = self.tenant_to_schema(tenant_id);
        if let Some(ref template) = self.config.template_schema {
            format!(
                "CREATE SCHEMA IF NOT EXISTS {} TEMPLATE {}",
                self.escape_identifier(&schema),
                self.escape_identifier(template)
            )
        } else {
            format!(
                "CREATE SCHEMA IF NOT EXISTS {}",
                self.escape_identifier(&schema)
            )
        }
    }

    /// Returns `DROP SCHEMA IF EXISTS` SQL for the tenant's schema; `cascade` appends
    /// `CASCADE`, which also drops every object inside the schema.
    pub fn drop_schema_sql(&self, tenant_id: &TenantId, cascade: bool) -> String {
        let schema = self.tenant_to_schema(tenant_id);
        let cascade_str = if cascade { " CASCADE" } else { "" };
        format!(
            "DROP SCHEMA IF EXISTS {}{}",
            self.escape_identifier(&schema),
            cascade_str
        )
    }

    /// Returns a single-row boolean query reporting whether the tenant's schema exists.
    ///
    /// NOTE(review): `information_schema.schemata` lists only schemas the current role
    /// owns or holds a privilege on, so this can report `false` for an existing schema
    /// the role cannot access; `pg_catalog.pg_namespace` has no such filter.
    pub fn schema_exists_sql(&self, tenant_id: &TenantId) -> String {
        let schema = self.tenant_to_schema(tenant_id);
        format!(
            "SELECT EXISTS(SELECT 1 FROM information_schema.schemata WHERE schema_name = '{}')",
            self.escape_sql_string(&schema)
        )
    }

    /// Returns a query listing, in name order, every schema whose name starts with
    /// `schema_prefix` taken literally.
    pub fn list_tenant_schemas_sql(&self) -> String {
        let prefix = &self.config.schema_prefix;
        // In a LIKE pattern `_` and `%` are wildcards and `\` is the escape character, so
        // `LIKE 'tenant_%'` alone also matches names such as `tenants_archive`. Doubling
        // `\` keeps the LIKE a superset filter; the `left(...) = ...` comparison then
        // requires the prefix to match character for character.
        let like_prefix = self.escape_sql_string(&prefix.replace('\\', "\\\\"));
        let literal_prefix = self.escape_sql_string(prefix);
        format!(
            "SELECT schema_name FROM information_schema.schemata \
             WHERE schema_name LIKE '{0}%' \
             AND left(schema_name, char_length('{1}')) = '{1}' \
             ORDER BY schema_name",
            like_prefix, literal_prefix
        )
    }

    /// Quotes `id` as an SQL identifier, doubling any embedded `"`.
    fn escape_identifier(&self, id: &str) -> String {
        format!("\"{}\"", id.replace('"', "\"\""))
    }

    /// Escapes `s` for use inside a single-quoted SQL literal by doubling every `'`.
    ///
    /// Backslashes are left untouched, which is correct only while
    /// `standard_conforming_strings` is on (the PostgreSQL default) — NOTE(review): confirm
    /// for the target server.
    fn escape_sql_string(&self, s: &str) -> String {
        s.replace('\'', "''")
    }

    /// Checks `schema` against `max_schema_length` (in bytes) and the compiled `schema_pattern`.
    ///
    /// # Errors
    ///
    /// Returns a `TenantValidationError` carrying the schema name when it is too long or
    /// does not match the pattern.
    fn validate_schema_name(&self, schema: &str) -> Result<(), TenantValidationError> {
        if schema.len() > self.config.max_schema_length {
            return Err(TenantValidationError {
                tenant_id: schema.to_string(),
                reason: format!(
                    "schema name exceeds maximum length of {} characters",
                    self.config.max_schema_length
                ),
            });
        }
        if !self.schema_pattern.is_match(schema) {
            return Err(TenantValidationError {
                tenant_id: schema.to_string(),
                reason: format!(
                    "schema name does not match required pattern: {}",
                    self.config.schema_pattern
                ),
            });
        }
        Ok(())
    }
}
impl TenantResolver for SchemaPerTenantStrategy {
fn resolve(&self, tenant_id: &TenantId) -> TenantResolution {
TenantResolution::Schema {
schema_name: self.tenant_to_schema(tenant_id),
}
}
fn validate(&self, tenant_id: &TenantId) -> Result<(), TenantValidationError> {
let schema = self.tenant_to_schema(tenant_id);
self.validate_schema_name(&schema)
}
fn system_tenant(&self) -> TenantResolution {
TenantResolution::Schema {
schema_name: self.config.shared_schema.clone(),
}
}
}
/// Generates schema-management SQL (shared-schema creation, per-schema DDL, migrations
/// across tenant schemas) using a borrowed strategy's configuration and escaping rules.
#[derive(Debug)]
// NOTE(review): presumably allowed because `SchemaManager` is not yet used elsewhere in
// the crate — confirm, and drop the allow once it is.
#[allow(dead_code)]
pub struct SchemaManager<'a> {
/// Strategy whose configuration and escaping helpers drive the generated SQL.
strategy: &'a SchemaPerTenantStrategy,
}
#[allow(dead_code)]
impl<'a> SchemaManager<'a> {
    /// Creates a manager that uses `strategy`'s configuration and escaping rules.
    pub fn new(strategy: &'a SchemaPerTenantStrategy) -> Self {
        Self { strategy }
    }

    /// Returns `CREATE SCHEMA IF NOT EXISTS` DDL for the configured shared schema.
    pub fn create_shared_schema_ddl(&self) -> String {
        format!(
            "CREATE SCHEMA IF NOT EXISTS {}",
            self.strategy
                .escape_identifier(&self.strategy.config.shared_schema)
        )
    }

    /// Returns a script that sets `search_path` to `schema` and then runs `table_ddl`.
    ///
    /// NOTE(review): the `SET` is session-level, so `search_path` stays on `schema` after
    /// the script runs; callers on pooled connections should reset it afterwards.
    #[allow(dead_code)]
    pub fn create_table_ddl(&self, schema: &str, table_ddl: &str) -> String {
        format!(
            "SET search_path TO {};\n{}",
            self.strategy.escape_identifier(schema),
            table_ddl
        )
    }

    /// Returns a PL/pgSQL `DO` block that runs `migration_sql` once in every schema whose
    /// name starts with the configured `schema_prefix` (taken literally), in name order.
    ///
    /// `migration_sql` is plain SQL (one or more statements). It is embedded as an
    /// escaped string literal and run with `EXECUTE` while `search_path` points at a
    /// single tenant schema (shared/`public` are not on the path). The caller's
    /// `search_path` is restored after the loop. The `DO` body is dollar-quoted with a
    /// tag that does not occur anywhere in the body, so a `$$` inside the migration
    /// cannot end the block early.
    #[allow(dead_code)]
    pub fn migrate_all_schemas_sql(&self, migration_sql: &str) -> String {
        let prefix = self
            .strategy
            .escape_sql_string(&self.strategy.config.schema_prefix);
        // Quote-doubling is only correct inside a single-quoted literal, so the migration
        // is passed to EXECUTE as one; spliced bare into the dollar-quoted body, every
        // `'x'` in it would become the invalid `''x''`.
        let migration = self.strategy.escape_sql_string(migration_sql);
        // `left(...) = ...` matches the prefix literally; LIKE would treat `_`/`%` in it
        // as wildcards and also pick up e.g. `tenants_archive` for prefix `tenant_`.
        let body = format!(
            r#"
DECLARE
    original_search_path TEXT := current_setting('search_path');
    tenant_schema TEXT;
BEGIN
    FOR tenant_schema IN
        SELECT s.schema_name
        FROM information_schema.schemata s
        WHERE left(s.schema_name, char_length('{prefix}')) = '{prefix}'
        ORDER BY s.schema_name
    LOOP
        EXECUTE format('SET search_path TO %I', tenant_schema);
        EXECUTE '{migration}';
    END LOOP;
    PERFORM set_config('search_path', original_search_path, false);
END
"#,
            prefix = prefix,
            migration = migration
        );
        let tag = Self::unused_dollar_quote_tag(&body);
        format!("\nDO {tag}{body}{tag};\n", tag = tag, body = body)
    }

    /// Returns a dollar-quote delimiter (`$migration$`, then `$migration1$`,
    /// `$migration2$`, ...) that does not occur in `body`.
    fn unused_dollar_quote_tag(body: &str) -> String {
        let mut tag = String::from("$migration$");
        let mut suffix: u32 = 0;
        while body.contains(tag.as_str()) {
            suffix += 1;
            tag = format!("$migration{}$", suffix);
        }
        tag
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Strategy built from the default configuration, shared by most tests below.
    fn default_strategy() -> SchemaPerTenantStrategy {
        SchemaPerTenantStrategy::new(SchemaPerTenantConfig::default())
            .expect("default schema pattern must compile")
    }

    #[test]
    fn test_schema_per_tenant_config_default() {
        let config = SchemaPerTenantConfig::default();
        assert_eq!(config.schema_prefix, "tenant_");
        assert_eq!(config.shared_schema, "shared");
        assert!(config.auto_create_schema);
        assert!(config.system_uses_public);
        assert_eq!(config.max_schema_length, 63);
        assert!(!config.drop_on_delete);
        assert_eq!(config.template_schema, None);
    }

    #[test]
    fn test_schema_per_tenant_config_builder() {
        let config = SchemaPerTenantConfig::new()
            .with_prefix("org_")
            .with_shared_schema("common")
            .with_template("template_tenant")
            .with_drop_on_delete();
        assert_eq!(config.schema_prefix, "org_");
        assert_eq!(config.shared_schema, "common");
        assert_eq!(config.template_schema, Some("template_tenant".to_string()));
        assert!(config.drop_on_delete);
    }

    #[test]
    fn test_invalid_schema_pattern_is_rejected() {
        let mut config = SchemaPerTenantConfig::default();
        config.schema_pattern = "(".to_string();
        assert!(SchemaPerTenantStrategy::new(config).is_err());
    }

    #[test]
    fn test_tenant_to_schema() {
        let strategy = default_strategy();
        assert_eq!(
            strategy.tenant_to_schema(&TenantId::new("acme")),
            "tenant_acme"
        );
        assert_eq!(
            strategy.tenant_to_schema(&TenantId::new("Acme-Corp")),
            "tenant_acme_corp"
        );
        assert_eq!(
            strategy.tenant_to_schema(&TenantId::new("acme/research")),
            "tenant_acme_research"
        );
        // Characters outside [a-z0-9_] (after mapping '/' and '-') are dropped.
        assert_eq!(
            strategy.tenant_to_schema(&TenantId::new("acme.corp!")),
            "tenant_acmecorp"
        );
    }

    #[test]
    fn test_tenant_resolution() {
        let strategy = default_strategy();
        match strategy.resolve(&TenantId::new("acme")) {
            TenantResolution::Schema { schema_name } => {
                assert_eq!(schema_name, "tenant_acme");
            }
            _ => panic!("expected Schema resolution"),
        }
    }

    #[test]
    fn test_set_search_path_sql() {
        let sql = default_strategy().set_search_path_sql(&TenantId::new("acme"));
        assert_eq!(
            sql,
            "SET search_path TO \"tenant_acme\", \"shared\", public"
        );
    }

    #[test]
    fn test_set_system_search_path_sql() {
        assert_eq!(
            default_strategy().set_system_search_path_sql(),
            "SET search_path TO \"shared\", public"
        );
        let mut config = SchemaPerTenantConfig::default();
        config.system_uses_public = false;
        let strategy = SchemaPerTenantStrategy::new(config).unwrap();
        assert_eq!(
            strategy.set_system_search_path_sql(),
            "SET search_path TO \"shared\""
        );
    }

    #[test]
    fn test_reset_search_path_sql() {
        assert_eq!(default_strategy().reset_search_path_sql(), "RESET search_path");
    }

    #[test]
    fn test_create_schema_sql() {
        let sql = default_strategy().create_schema_sql(&TenantId::new("acme"));
        assert_eq!(sql, "CREATE SCHEMA IF NOT EXISTS \"tenant_acme\"");
    }

    #[test]
    fn test_create_schema_sql_with_template() {
        let config = SchemaPerTenantConfig::new().with_template("tenant_template");
        let strategy = SchemaPerTenantStrategy::new(config).unwrap();
        let sql = strategy.create_schema_sql(&TenantId::new("acme"));
        assert!(sql.contains("TEMPLATE"));
        assert!(sql.contains("tenant_template"));
    }

    #[test]
    fn test_drop_schema_sql() {
        let strategy = default_strategy();
        let sql = strategy.drop_schema_sql(&TenantId::new("acme"), false);
        assert_eq!(sql, "DROP SCHEMA IF EXISTS \"tenant_acme\"");
        let sql_cascade = strategy.drop_schema_sql(&TenantId::new("acme"), true);
        assert_eq!(sql_cascade, "DROP SCHEMA IF EXISTS \"tenant_acme\" CASCADE");
    }

    #[test]
    fn test_schema_exists_sql() {
        let sql = default_strategy().schema_exists_sql(&TenantId::new("acme"));
        assert!(sql.contains("information_schema.schemata"));
        assert!(sql.contains("tenant_acme"));
    }

    #[test]
    fn test_schema_exists_sql_escapes_quotes() {
        let strategy =
            SchemaPerTenantStrategy::new(SchemaPerTenantConfig::new().with_prefix("o'")).unwrap();
        let sql = strategy.schema_exists_sql(&TenantId::new("acme"));
        assert!(sql.contains("'o''acme'"));
    }

    #[test]
    fn test_list_tenant_schemas_sql() {
        let sql = default_strategy().list_tenant_schemas_sql();
        assert!(sql.contains("LIKE 'tenant_%'"));
    }

    #[test]
    fn test_system_tenant_resolution() {
        match default_strategy().system_tenant() {
            TenantResolution::Schema { schema_name } => {
                assert_eq!(schema_name, "shared");
            }
            _ => panic!("expected Schema resolution"),
        }
    }

    #[test]
    fn test_schema_manager_create_shared() {
        let strategy = default_strategy();
        let manager = SchemaManager::new(&strategy);
        let ddl = manager.create_shared_schema_ddl();
        assert!(ddl.contains("CREATE SCHEMA IF NOT EXISTS"));
        assert!(ddl.contains("shared"));
    }

    #[test]
    fn test_schema_manager_create_table_ddl() {
        let strategy = default_strategy();
        let manager = SchemaManager::new(&strategy);
        assert_eq!(
            manager.create_table_ddl("tenant_acme", "CREATE TABLE t (id int)"),
            "SET search_path TO \"tenant_acme\";\nCREATE TABLE t (id int)"
        );
    }

    #[test]
    fn test_tenant_validation_valid() {
        let strategy = default_strategy();
        assert!(strategy.validate(&TenantId::new("acme")).is_ok());
        assert!(strategy.validate(&TenantId::new("acme-corp")).is_ok());
    }

    #[test]
    fn test_tenant_validation_rejects_overlong_schema() {
        let strategy = default_strategy();
        let id = TenantId::new(
            "a_very_long_tenant_identifier_that_exceeds_the_postgres_identifier_limit",
        );
        assert!(strategy.validate(&id).is_err());
    }

    #[test]
    fn test_tenant_validation_rejects_pattern_mismatch() {
        let strategy =
            SchemaPerTenantStrategy::new(SchemaPerTenantConfig::new().with_prefix("")).unwrap();
        assert!(strategy.validate(&TenantId::new("1acme")).is_err());
    }

    #[test]
    fn test_escape_identifier() {
        let escaped = default_strategy().escape_identifier("test\"schema");
        assert_eq!(escaped, "\"test\"\"schema\"");
    }
}