use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs;
use std::path::Path;
/// Three-valued mode setting: `none`, `lookups` (the default) or `all`.
///
/// serde reads and writes the variants using their lowercase names.
/// NOTE(review): presumably selects which global (non-tenant) tables are
/// included in the output — confirm against the callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum GlobalTableMode {
/// Serialized as `none`.
None,
/// Serialized as `lookups`; this is the `Default` value.
#[default]
Lookups,
/// Serialized as `all`.
All,
}
impl std::str::FromStr for GlobalTableMode {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"none" => Ok(GlobalTableMode::None),
"lookups" => Ok(GlobalTableMode::Lookups),
"all" => Ok(GlobalTableMode::All),
_ => Err(format!(
"Unknown global mode: {}. Valid options: none, lookups, all",
s
)),
}
}
}
impl std::fmt::Display for GlobalTableMode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
GlobalTableMode::None => write!(f, "none"),
GlobalTableMode::Lookups => write!(f, "lookups"),
GlobalTableMode::All => write!(f, "all"),
}
}
}
/// Classification assigned to a table.
///
/// serde writes each variant as its all-lowercase name (`tenantroot`,
/// `tenantdependent`, `junction`, ...). The `Display` impl prints the two
/// multi-word variants with a hyphen (`tenant-root`, `tenant-dependent`), so
/// those hyphenated spellings are also accepted when deserializing. This lets
/// a name copied from program output be pasted into a config file.
/// Serialized output is unchanged.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ShardTableClassification {
    /// Serialized as `tenantroot`; also deserializes from `tenant-root`.
    #[serde(alias = "tenant-root")]
    TenantRoot,
    /// Serialized as `tenantdependent`; also deserializes from `tenant-dependent`.
    #[serde(alias = "tenant-dependent")]
    TenantDependent,
    /// Serialized as `junction`.
    Junction,
    /// Serialized as `lookup`.
    Lookup,
    /// Serialized as `system`.
    System,
    /// Serialized as `unknown`; this is the `Default` value.
    #[default]
    Unknown,
}
impl std::fmt::Display for ShardTableClassification {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ShardTableClassification::TenantRoot => write!(f, "tenant-root"),
ShardTableClassification::TenantDependent => write!(f, "tenant-dependent"),
ShardTableClassification::Junction => write!(f, "junction"),
ShardTableClassification::Lookup => write!(f, "lookup"),
ShardTableClassification::System => write!(f, "system"),
ShardTableClassification::Unknown => write!(f, "unknown"),
}
}
}
/// Per-table settings; the value type of `ShardYamlConfig::tables`.
///
/// The container-level `#[serde(default)]` means any key missing from the
/// input takes its `Default` value (`None` / `false`).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(default)]
pub struct TableOverride {
/// Explicit classification for the table; this is what
/// `ShardYamlConfig::get_classification` returns.
pub role: Option<ShardTableClassification>,
/// NOTE(review): not read anywhere in this section — presumably forces the
/// table into or out of the output; confirm against the callers.
pub include: Option<bool>,
/// Value returned by `ShardYamlConfig::get_self_fk`.
/// Presumably the name of a self-referencing foreign-key column — TODO confirm.
pub self_fk: Option<String>,
/// Value returned by `ShardYamlConfig::should_skip` for this table.
pub skip: bool,
}
/// The `tenant` section of `ShardYamlConfig`.
///
/// Both keys are optional on input; missing keys take their `Default` value.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(default)]
pub struct TenantConfig {
/// NOTE(review): presumably the name of the column that identifies the
/// tenant; `None` when the key is absent — confirm how callers resolve it.
pub column: Option<String>,
/// NOTE(review): presumably the names of tables to treat as tenant roots;
/// empty when the key is absent — confirm against the callers.
#[serde(default)]
pub root_tables: Vec<String>,
}
/// Top-level shape of the YAML document parsed by `ShardYamlConfig::load`.
///
/// Every key is optional; missing keys take their `Default` value.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(default)]
pub struct ShardYamlConfig {
/// The `tenant` section.
pub tenant: TenantConfig,
/// Per-table overrides keyed by table name. `get_table_override` tries the
/// exact key first and then falls back to a case-insensitive match.
#[serde(default)]
pub tables: HashMap<String, TableOverride>,
/// Optional `GlobalTableMode`; `None` when the key is absent.
pub include_global: Option<GlobalTableMode>,
}
impl ShardYamlConfig {
pub fn load(path: &Path) -> anyhow::Result<Self> {
let content = fs::read_to_string(path)?;
let config: ShardYamlConfig = serde_yml::from_str(&content)?;
Ok(config)
}
pub fn get_table_override(&self, table_name: &str) -> Option<&TableOverride> {
self.tables.get(table_name).or_else(|| {
let lower = table_name.to_lowercase();
self.tables
.iter()
.find(|(k, _)| k.to_lowercase() == lower)
.map(|(_, v)| v)
})
}
pub fn get_classification(&self, table_name: &str) -> Option<ShardTableClassification> {
self.get_table_override(table_name).and_then(|o| o.role)
}
pub fn should_skip(&self, table_name: &str) -> bool {
self.get_table_override(table_name)
.map(|o| o.skip)
.unwrap_or(false)
}
#[allow(dead_code)]
pub fn get_self_fk(&self, table_name: &str) -> Option<&str> {
self.get_table_override(table_name)
.and_then(|o| o.self_fk.as_deref())
}
}
/// Stateless namespace for the built-in, name-based table heuristics.
///
/// All matching is done on the Unicode-lowercased table name.
pub struct DefaultShardClassifier;

impl DefaultShardClassifier {
    /// Candidate tenant column names.
    ///
    /// NOTE(review): not referenced within this section; presumably callers
    /// probe table columns against this list — confirm.
    pub const TENANT_COLUMNS: &'static [&'static str] = &[
        "company_id",
        "tenant_id",
        "organization_id",
        "org_id",
        "account_id",
        "team_id",
        "workspace_id",
    ];

    /// Names recognized by [`Self::is_system_table`].
    ///
    /// An entry ending in `_` (e.g. `pulse_`) is a prefix; every other entry
    /// is a complete table name.
    pub const SYSTEM_PATTERNS: &'static [&'static str] = &[
        "migrations",
        "failed_jobs",
        "job_batches",
        "jobs",
        "cache",
        "cache_locks",
        "sessions",
        "password_reset_tokens",
        "personal_access_tokens",
        "telescope_entries",
        "telescope_entries_tags",
        "telescope_monitoring",
        "pulse_",
        "horizon_",
    ];

    /// Complete table names recognized by [`Self::is_lookup_table`].
    pub const LOOKUP_PATTERNS: &'static [&'static str] = &[
        "countries",
        "states",
        "provinces",
        "cities",
        "currencies",
        "languages",
        "timezones",
        "permissions",
        "roles",
        "settings",
    ];

    /// Returns `true` when `table_name` matches [`Self::SYSTEM_PATTERNS`],
    /// ignoring case.
    ///
    /// Entries ending in `_` match any table that starts with them. All other
    /// entries must equal the whole table name, so an application table such
    /// as `jobseekers` or `cached_reports` is not mistaken for `jobs` or
    /// `cache`.
    pub fn is_system_table(table_name: &str) -> bool {
        let lower = table_name.to_lowercase();
        Self::SYSTEM_PATTERNS.iter().any(|pattern| {
            if pattern.ends_with('_') {
                // Namespace-style entry: everything under the prefix counts.
                lower.starts_with(*pattern)
            } else {
                lower == *pattern
            }
        })
    }

    /// Returns `true` when `table_name` equals one of
    /// [`Self::LOOKUP_PATTERNS`], ignoring case.
    pub fn is_lookup_table(table_name: &str) -> bool {
        let lower = table_name.to_lowercase();
        Self::LOOKUP_PATTERNS
            .iter()
            .any(|candidate| lower == *candidate)
    }

    /// Returns `true` when the lowercased name contains `_has_` or ends with
    /// `_pivot`, `_link` or `_map`.
    pub fn is_junction_table_by_name(table_name: &str) -> bool {
        const JUNCTION_SUFFIXES: [&str; 3] = ["_pivot", "_link", "_map"];

        let lower = table_name.to_lowercase();
        lower.contains("_has_")
            || JUNCTION_SUFFIXES
                .iter()
                .any(|suffix| lower.ends_with(*suffix))
    }
}