use serde::{Deserialize, Serialize};
use crate::tenant::TenantId;
use super::{TenantResolution, TenantResolver, TenantValidationError};
/// Configuration for the shared-schema tenancy strategy, in which all tenants share the same
/// tables and rows are told apart by a tenant column.
///
/// Every field carries a serde default, so a config document may omit any of them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SharedSchemaConfig {
    /// Whether `TenantAwareTableBuilder` emits row-level-security DDL; also reported by
    /// `SharedSchemaStrategy::uses_rls`. Defaults to `false`.
    #[serde(default)]
    pub use_row_level_security: bool,
    /// Name of the column holding the tenant ID. Defaults to `"tenant_id"`.
    #[serde(default = "default_tenant_column")]
    pub tenant_column: String,
    /// Defaults to `true`.
    /// NOTE(review): not read by any code in this file — presumably meant to control whether the
    /// tenant column leads each index; confirm against other users of the config.
    #[serde(default = "default_true")]
    pub index_tenant_first: bool,
    /// Maximum tenant ID length, measured in bytes (`str::len`). Longer IDs fail validation
    /// unless `hash_long_ids` is set. Defaults to `64`.
    #[serde(default = "default_max_tenant_id_length")]
    pub max_tenant_id_length: usize,
    /// Regular expression every tenant ID must match; compiled by `SharedSchemaStrategy::new`.
    /// Defaults to `^[a-zA-Z0-9_\-/]+$`.
    #[serde(default = "default_tenant_id_pattern")]
    pub tenant_id_pattern: String,
    /// When `true`, IDs longer than `max_tenant_id_length` are accepted and replaced by `h_`
    /// followed by a 16-hex-digit hash instead of being rejected. Defaults to `false`.
    #[serde(default)]
    pub hash_long_ids: bool,
}
/// Serde default for `SharedSchemaConfig::tenant_column`.
fn default_tenant_column() -> String {
    String::from("tenant_id")
}

/// Serde default for boolean settings that are enabled unless explicitly turned off.
fn default_true() -> bool {
    true
}

/// Serde default for `SharedSchemaConfig::max_tenant_id_length`.
fn default_max_tenant_id_length() -> usize {
    const LIMIT: usize = 64;
    LIMIT
}

/// Serde default for `SharedSchemaConfig::tenant_id_pattern`: one or more ASCII letters,
/// digits, `_`, `-` or `/`.
fn default_tenant_id_pattern() -> String {
    String::from(r"^[a-zA-Z0-9_\-/]+$")
}
impl Default for SharedSchemaConfig {
fn default() -> Self {
Self {
use_row_level_security: false,
tenant_column: default_tenant_column(),
index_tenant_first: true,
max_tenant_id_length: default_max_tenant_id_length(),
tenant_id_pattern: default_tenant_id_pattern(),
hash_long_ids: false,
}
}
}
impl SharedSchemaConfig {
    /// Equivalent to `SharedSchemaConfig::default()`.
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Returns the config with row-level security switched on.
    pub fn with_rls(self) -> Self {
        Self {
            use_row_level_security: true,
            ..self
        }
    }

    /// Returns the config with `column` as the tenant column name.
    pub fn with_tenant_column(self, column: impl Into<String>) -> Self {
        Self {
            tenant_column: column.into(),
            ..self
        }
    }
}
/// Shared-schema tenancy strategy built from a [`SharedSchemaConfig`].
///
/// Holds the compiled `tenant_id_pattern` so that validation reuses a single `Regex`.
#[derive(Debug, Clone)]
pub struct SharedSchemaStrategy {
    /// Configuration the strategy was created from.
    config: SharedSchemaConfig,
    /// `config.tenant_id_pattern`, compiled once in `SharedSchemaStrategy::new`.
    tenant_pattern: regex::Regex,
}
impl SharedSchemaStrategy {
    /// Creates a strategy from `config`, compiling `config.tenant_id_pattern` once so that
    /// validation does not recompile it per call.
    ///
    /// # Errors
    ///
    /// Returns the [`regex::Error`] produced when `tenant_id_pattern` is not a valid regex.
    pub fn new(config: SharedSchemaConfig) -> Result<Self, regex::Error> {
        let tenant_pattern = regex::Regex::new(&config.tenant_id_pattern)?;
        Ok(Self {
            config,
            tenant_pattern,
        })
    }

    /// The configuration this strategy was built from.
    pub fn config(&self) -> &SharedSchemaConfig {
        &self.config
    }

    /// Name of the column that carries the tenant ID.
    pub fn tenant_column(&self) -> &str {
        &self.config.tenant_column
    }

    /// Whether row-level security is enabled in the configuration.
    pub fn uses_rls(&self) -> bool {
        self.config.use_row_level_security
    }

    /// Builds `SET LOCAL app.current_tenant = '<id>'` for `tenant_id`.
    ///
    /// The value written is the *normalized* ID — the same string `resolve` returns — so that
    /// when `hash_long_ids` replaces an over-long ID with its hash, the setting an RLS policy
    /// compares the tenant column against (`current_setting('app.current_tenant')`) agrees with
    /// the resolved tenant ID rather than the raw one. For IDs that are not hashed the output is
    /// unchanged.
    pub fn set_tenant_sql(&self, tenant_id: &TenantId) -> String {
        let normalized = self.normalize_tenant_id(tenant_id);
        format!(
            "SET LOCAL app.current_tenant = '{}'",
            self.escape_sql_string(&normalized)
        )
    }

    /// Builds the statement that resets `app.current_tenant`.
    pub fn clear_tenant_sql(&self) -> String {
        "RESET app.current_tenant".to_string()
    }

    /// Builds the tenant predicate `[alias.]<tenant_column> = $tenant_id`.
    ///
    /// `$tenant_id` is emitted literally; presumably the caller's query layer binds it — confirm.
    pub fn tenant_filter_sql(&self, table_alias: Option<&str>) -> String {
        match table_alias {
            Some(alias) => format!("{}.{} = $tenant_id", alias, self.config.tenant_column),
            None => format!("{} = $tenant_id", self.config.tenant_column),
        }
    }

    /// Escapes `s` for a single-quoted SQL literal by doubling every `'`.
    ///
    /// NOTE(review): sufficient only when backslashes are not escape characters (PostgreSQL with
    /// `standard_conforming_strings = on`) — confirm for the target database.
    fn escape_sql_string(&self, s: &str) -> String {
        s.replace('\'', "''")
    }

    /// Returns the tenant ID to use for `tenant_id`: the ID itself, or — when `hash_long_ids` is
    /// enabled and the ID is longer (in bytes) than `max_tenant_id_length` — `h_` followed by a
    /// 16-digit lowercase hex hash of it.
    fn normalize_tenant_id(&self, tenant_id: &TenantId) -> String {
        let id = tenant_id.as_str();
        if self.config.hash_long_ids && id.len() > self.config.max_tenant_id_length {
            format!("h_{:016x}", Self::stable_tenant_hash(id))
        } else {
            id.to_string()
        }
    }

    /// SipHash-1-3 with all-zero keys over the UTF-8 bytes of `id` followed by one `0xFF` byte.
    ///
    /// This is the value `DefaultHasher::new()` + `id.hash(..)` produces in current Rust
    /// releases, written out by hand because `DefaultHasher`'s algorithm is documented as
    /// unspecified and may change between releases. A hashed tenant ID presumably ends up stored
    /// in the tenant column, so it must not drift when the toolchain is upgraded.
    fn stable_tenant_hash(id: &str) -> u64 {
        /// One SipRound over the state `[v0, v1, v2, v3]`.
        fn sip_round(v: &mut [u64; 4]) {
            v[0] = v[0].wrapping_add(v[1]);
            v[1] = v[1].rotate_left(13);
            v[1] ^= v[0];
            v[0] = v[0].rotate_left(32);
            v[2] = v[2].wrapping_add(v[3]);
            v[3] = v[3].rotate_left(16);
            v[3] ^= v[2];
            v[0] = v[0].wrapping_add(v[3]);
            v[3] = v[3].rotate_left(21);
            v[3] ^= v[0];
            v[2] = v[2].wrapping_add(v[1]);
            v[1] = v[1].rotate_left(17);
            v[1] ^= v[2];
            v[2] = v[2].rotate_left(32);
        }

        // k0 = k1 = 0, so the initial state is just the SipHash constants.
        let mut v: [u64; 4] = [
            0x736f_6d65_7073_6575,
            0x646f_7261_6e64_6f6d,
            0x6c79_6765_6e65_7261,
            0x7465_6462_7974_6573,
        ];

        // `str`'s Hash impl feeds the bytes plus a 0xFF terminator.
        let total_len = id.len() + 1;
        let mut word: u64 = 0;
        let mut filled: u32 = 0;
        for byte in id.bytes().chain(std::iter::once(0xFF_u8)) {
            // Pack the message little-endian into 64-bit words.
            word |= u64::from(byte) << (8 * filled);
            filled += 1;
            if filled == 8 {
                v[3] ^= word;
                sip_round(&mut v); // c = 1 compression round
                v[0] ^= word;
                word = 0;
                filled = 0;
            }
        }

        // Final block: leftover bytes in the low end, message length mod 256 in the top byte.
        let last = word | (((total_len as u64) & 0xff) << 56);
        v[3] ^= last;
        sip_round(&mut v);
        v[0] ^= last;
        v[2] ^= 0xff;
        for _ in 0..3 {
            sip_round(&mut v); // d = 3 finalization rounds
        }
        v[0] ^ v[1] ^ v[2] ^ v[3]
    }
}
impl TenantResolver for SharedSchemaStrategy {
    /// Resolves `tenant_id` to a shared-schema resolution carrying the normalized (possibly
    /// hashed) form of the ID.
    fn resolve(&self, tenant_id: &TenantId) -> TenantResolution {
        let normalized = self.normalize_tenant_id(tenant_id);
        TenantResolution::SharedSchema {
            tenant_id: normalized,
        }
    }

    /// Checks `tenant_id` against the configured length limit and pattern.
    ///
    /// The length limit (bytes, via `str::len`) is skipped when `hash_long_ids` is set, because
    /// over-long IDs are then hashed by `resolve` instead of rejected. The length check runs
    /// before the pattern check, so an ID failing both reports the length error.
    ///
    /// NOTE(review): with `hash_long_ids` enabled, a raw ID that already looks like a hashed one
    /// (`h_` + 16 hex digits) passes the default pattern and resolves to itself, so it could
    /// coincide with the hashed form of some long ID — confirm whether such IDs should be
    /// rejected.
    fn validate(&self, tenant_id: &TenantId) -> Result<(), TenantValidationError> {
        let id = tenant_id.as_str();
        let limit = self.config.max_tenant_id_length;
        let reject = |reason: String| TenantValidationError {
            tenant_id: id.to_string(),
            reason,
        };

        let exceeds_limit = id.len() > limit;
        if exceeds_limit && !self.config.hash_long_ids {
            return Err(reject(format!(
                "tenant ID exceeds maximum length of {} characters",
                limit
            )));
        }

        if self.tenant_pattern.is_match(id) {
            Ok(())
        } else {
            Err(reject(format!(
                "tenant ID does not match required pattern: {}",
                self.config.tenant_id_pattern
            )))
        }
    }

    /// Resolution for `crate::tenant::SYSTEM_TENANT`, used verbatim (no normalization).
    fn system_tenant(&self) -> TenantResolution {
        let system_id = crate::tenant::SYSTEM_TENANT;
        TenantResolution::SharedSchema {
            tenant_id: system_id.to_string(),
        }
    }
}
/// Builds a PostgreSQL DDL script for a table whose first column is the tenant column, whose
/// indexes all lead with that column, and which optionally enables row-level security.
#[derive(Debug)]
// NOTE(review): presumably the builder is only exercised from tests in some builds; the allow
// keeps dead-code warnings quiet there.
#[allow(dead_code)]
pub struct TenantAwareTableBuilder {
    table_name: String,
    tenant_column: String,
    /// Declared `VARCHAR` width of the tenant column.
    tenant_column_width: usize,
    columns: Vec<ColumnDef>,
    indexes: Vec<IndexDef>,
    use_rls: bool,
}

/// A column added via `TenantAwareTableBuilder::column`; `data_type` is emitted verbatim.
#[derive(Debug)]
#[allow(dead_code)]
struct ColumnDef {
    name: String,
    data_type: String,
    nullable: bool,
}

/// An index added via `TenantAwareTableBuilder::index`; the tenant column is prepended to
/// `columns` when the DDL is rendered.
#[derive(Debug)]
#[allow(dead_code)]
struct IndexDef {
    name: String,
    columns: Vec<String>,
    unique: bool,
}

#[allow(dead_code)]
impl TenantAwareTableBuilder {
    /// Length of the IDs produced when `hash_long_ids` replaces an over-long tenant ID
    /// (`h_` followed by 16 hex digits).
    const HASHED_TENANT_ID_LEN: usize = 18;

    /// Starts a builder for `table_name`, taking the tenant column name, the RLS flag and the
    /// tenant column width from `config`.
    ///
    /// The width is `max_tenant_id_length`, so every ID that passes the length check fits
    /// (previously it was hard-coded to 64 regardless of the configured limit). It is widened
    /// to fit a hashed ID when `hash_long_ids` is set, and never drops below 1 because a zero
    /// width is not a valid `VARCHAR` length. With the default config the width is 64.
    ///
    /// NOTE(review): `config.index_tenant_first` is not consulted; the tenant column always
    /// leads every index.
    pub fn new(table_name: impl Into<String>, config: &SharedSchemaConfig) -> Self {
        let mut width = config.max_tenant_id_length;
        if config.hash_long_ids {
            width = width.max(Self::HASHED_TENANT_ID_LEN);
        }
        Self {
            table_name: table_name.into(),
            tenant_column: config.tenant_column.clone(),
            tenant_column_width: width.max(1),
            columns: Vec::new(),
            indexes: Vec::new(),
            use_rls: config.use_row_level_security,
        }
    }

    /// Appends a column; `NOT NULL` is emitted when `nullable` is `false`.
    pub fn column(
        mut self,
        name: impl Into<String>,
        data_type: impl Into<String>,
        nullable: bool,
    ) -> Self {
        self.columns.push(ColumnDef {
            name: name.into(),
            data_type: data_type.into(),
            nullable,
        });
        self
    }

    /// Appends an index over `columns`; the tenant column is prepended when rendering.
    pub fn index(mut self, name: impl Into<String>, columns: Vec<&str>, unique: bool) -> Self {
        self.indexes.push(IndexDef {
            name: name.into(),
            columns: columns.into_iter().map(String::from).collect(),
            unique,
        });
        self
    }

    /// Renders the DDL script:
    ///
    /// * `CREATE TABLE IF NOT EXISTS` with the tenant column first, then the added columns in
    ///   insertion order;
    /// * one `CREATE [UNIQUE] INDEX IF NOT EXISTS` per index, tenant column first;
    /// * when RLS is enabled, `ENABLE ROW LEVEL SECURITY` plus a re-runnable
    ///   `DROP POLICY IF EXISTS` / `CREATE POLICY tenant_isolation` pair.
    ///
    /// Identifiers and data types are interpolated verbatim, so they must come from trusted code.
    pub fn to_postgres_ddl(&self) -> String {
        let tenant_def = format!(
            " {} VARCHAR({}) NOT NULL",
            self.tenant_column, self.tenant_column_width
        );
        let column_defs: Vec<String> = std::iter::once(tenant_def)
            .chain(self.columns.iter().map(|col| {
                let null_str = if col.nullable { "" } else { " NOT NULL" };
                format!(" {} {}{}", col.name, col.data_type, null_str)
            }))
            .collect();

        let mut ddl = format!("CREATE TABLE IF NOT EXISTS {} (\n", self.table_name);
        ddl.push_str(&column_defs.join(",\n"));
        ddl.push_str("\n);\n\n");

        for idx in &self.indexes {
            let unique_str = if idx.unique { "UNIQUE " } else { "" };
            let columns: Vec<_> = std::iter::once(self.tenant_column.as_str())
                .chain(idx.columns.iter().map(|s| s.as_str()))
                .collect();
            ddl.push_str(&format!(
                "CREATE {}INDEX IF NOT EXISTS {} ON {} ({});\n",
                unique_str,
                idx.name,
                self.table_name,
                columns.join(", ")
            ));
        }

        if self.use_rls {
            ddl.push_str(&format!(
                "\nALTER TABLE {} ENABLE ROW LEVEL SECURITY;\n",
                self.table_name
            ));
            // CREATE POLICY has no IF NOT EXISTS form; dropping first keeps the script re-runnable
            // like the IF NOT EXISTS table/index statements above.
            ddl.push_str(&format!(
                "DROP POLICY IF EXISTS tenant_isolation ON {};\n",
                self.table_name
            ));
            // NOTE(review): PostgreSQL lets the table owner bypass RLS unless
            // FORCE ROW LEVEL SECURITY is set — confirm the application role is not the owner.
            ddl.push_str(&format!(
                "CREATE POLICY tenant_isolation ON {} USING ({} = current_setting('app.current_tenant'));\n",
                self.table_name, self.tenant_column
            ));
        }
        ddl
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Strategy over the default configuration, shared by most tests below.
    fn default_strategy() -> SharedSchemaStrategy {
        SharedSchemaStrategy::new(SharedSchemaConfig::default()).unwrap()
    }

    /// Unwraps the tenant ID of a shared-schema resolution, failing the test otherwise.
    fn shared_tenant_id(resolution: TenantResolution) -> String {
        match resolution {
            TenantResolution::SharedSchema { tenant_id } => tenant_id,
            _ => panic!("expected SharedSchema resolution"),
        }
    }

    #[test]
    fn test_shared_schema_config_default() {
        let cfg = SharedSchemaConfig::default();
        assert_eq!(cfg.tenant_column, "tenant_id");
        assert!(!cfg.use_row_level_security);
        assert!(cfg.index_tenant_first);
    }

    #[test]
    fn test_shared_schema_config_builder() {
        let cfg = SharedSchemaConfig::new()
            .with_rls()
            .with_tenant_column("org_id");
        assert!(cfg.use_row_level_security);
        assert_eq!(cfg.tenant_column, "org_id");
    }

    #[test]
    fn test_shared_schema_strategy_creation() {
        assert_eq!(default_strategy().tenant_column(), "tenant_id");
    }

    #[test]
    fn test_tenant_resolution() {
        let resolved = shared_tenant_id(default_strategy().resolve(&TenantId::new("acme")));
        assert_eq!(resolved, "acme");
    }

    #[test]
    fn test_tenant_validation_valid() {
        let strategy = default_strategy();
        for &id in &["acme", "acme/research", "tenant_123"] {
            assert!(strategy.validate(&TenantId::new(id)).is_ok());
        }
    }

    #[test]
    fn test_tenant_validation_invalid_pattern() {
        let outcome = default_strategy().validate(&TenantId::new("tenant with spaces"));
        assert!(outcome.is_err());
    }

    #[test]
    fn test_tenant_validation_too_long() {
        let cfg = SharedSchemaConfig {
            max_tenant_id_length: 10,
            ..Default::default()
        };
        let strategy = SharedSchemaStrategy::new(cfg).unwrap();
        let outcome = strategy.validate(&TenantId::new("this-is-a-very-long-tenant-id"));
        assert!(outcome.is_err());
    }

    #[test]
    fn test_set_tenant_sql() {
        let stmt = default_strategy().set_tenant_sql(&TenantId::new("acme"));
        assert_eq!(stmt, "SET LOCAL app.current_tenant = 'acme'");
    }

    #[test]
    fn test_set_tenant_sql_escapes() {
        let stmt = default_strategy().set_tenant_sql(&TenantId::new("o'brien"));
        assert_eq!(stmt, "SET LOCAL app.current_tenant = 'o''brien'");
    }

    #[test]
    fn test_tenant_filter_sql() {
        let strategy = default_strategy();
        assert_eq!(strategy.tenant_filter_sql(None), "tenant_id = $tenant_id");
        assert_eq!(
            strategy.tenant_filter_sql(Some("p")),
            "p.tenant_id = $tenant_id"
        );
    }

    #[test]
    fn test_table_builder() {
        let cfg = SharedSchemaConfig::default();
        let ddl = TenantAwareTableBuilder::new("patient", &cfg)
            .column("id", "VARCHAR(64)", false)
            .column("family_name", "TEXT", true)
            .index("idx_patient_id", vec!["id"], true)
            .to_postgres_ddl();
        let expected_fragments = [
            "CREATE TABLE IF NOT EXISTS patient",
            "tenant_id VARCHAR(64) NOT NULL",
            "id VARCHAR(64) NOT NULL",
            "CREATE UNIQUE INDEX",
            "(tenant_id, id)",
        ];
        for fragment in expected_fragments.iter() {
            assert!(ddl.contains(fragment));
        }
    }

    #[test]
    fn test_table_builder_with_rls() {
        let cfg = SharedSchemaConfig::new().with_rls();
        let ddl = TenantAwareTableBuilder::new("patient", &cfg)
            .column("id", "VARCHAR(64)", false)
            .to_postgres_ddl();
        assert!(ddl.contains("ENABLE ROW LEVEL SECURITY"));
        assert!(ddl.contains("CREATE POLICY tenant_isolation"));
    }

    #[test]
    fn test_system_tenant_resolution() {
        let resolved = shared_tenant_id(default_strategy().system_tenant());
        assert_eq!(resolved, crate::tenant::SYSTEM_TENANT);
    }
}