use std::collections::HashMap;
use std::path::PathBuf;
/// Static, per-dialect settings returned by [`Dialect::cfg`].
struct DialectCfg {
    /// Sub-directory of `migrations/` that holds this dialect's SQL files.
    migration_dir: &'static str,
    /// Suffix appended to `schema` to form the schema file name.
    schema_ext: &'static str,
    /// Leading character of a bind placeholder (`?` or `$`).
    ph_prefix: &'static str,
}

/// SQL backend targeted by generated queries.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Dialect {
    Sqlite,
    Postgres,
    Mysql,
}

impl Dialect {
    /// Returns this dialect's static settings.
    fn cfg(self) -> &'static DialectCfg {
        // An exhaustive match (instead of indexing a table with `self as usize`)
        // binds each variant to its own settings explicitly: reordering variants
        // cannot hand out another dialect's row, and a new variant fails to
        // compile until it gets one. The literals are promoted to `'static`.
        match self {
            Dialect::Sqlite => &DialectCfg {
                migration_dir: "sqlite",
                schema_ext: ".sqlite.sql",
                ph_prefix: "?",
            },
            Dialect::Postgres => &DialectCfg {
                migration_dir: "postgres",
                schema_ext: ".postgres.sql",
                ph_prefix: "$",
            },
            Dialect::Mysql => &DialectCfg {
                migration_dir: "mysql",
                schema_ext: ".mysql.sql",
                ph_prefix: "?",
            },
        }
    }

    /// Detects the dialect from the `DATABASE_URL` environment variable.
    ///
    /// An unset variable or an unrecognised URL selects [`Dialect::Sqlite`];
    /// see [`Dialect::from_url`] for the recognised schemes.
    pub fn from_env() -> Self {
        Self::from_url(&std::env::var("DATABASE_URL").unwrap_or_default())
    }

    /// Detects the dialect from the scheme of a database URL.
    ///
    /// `postgres:` / `postgresql:` select Postgres and `mysql:` / `mariadb:`
    /// select MySQL. Schemes are compared case-insensitively, as URL schemes
    /// are. Anything else, including `sqlite:` URLs, text without a `:` and the
    /// empty string, selects SQLite.
    pub fn from_url(url: &str) -> Self {
        let scheme = url
            .trim()
            .split_once(':')
            .map(|(scheme, _)| scheme.to_ascii_lowercase());
        match scheme.as_deref() {
            Some("postgres" | "postgresql") => Dialect::Postgres,
            Some("mysql" | "mariadb") => Dialect::Mysql,
            _ => Dialect::Sqlite,
        }
    }

    /// Returns the bind placeholder for parameter number `idx`.
    ///
    /// Postgres yields `$<idx>` and SQLite yields `?<idx>`. MySQL only supports
    /// positional `?`, so `idx` is ignored there.
    pub fn ph(&self, idx: usize) -> String {
        match self {
            Dialect::Postgres => format!("${idx}"),
            Dialect::Sqlite => format!("?{idx}"),
            Dialect::Mysql => "?".to_string(),
        }
    }

    /// Returns a placeholder that carries no parameter number.
    ///
    /// NOTE(review): for Postgres (prefix `$`) this yields `"$?"`, which is not a
    /// valid Postgres bind parameter (see [`Dialect::ph`], which emits `$<n>`).
    /// Presumably a caller rewrites it later; confirm before using.
    #[expect(dead_code)]
    pub fn ph_unnumbered(&self) -> &'static str {
        if self.cfg().ph_prefix == "$" {
            "$?"
        } else {
            "?"
        }
    }

    /// Returns a ` LIMIT ? OFFSET ?` suffix with two positional placeholders.
    ///
    /// NOTE(review): the same `?` placeholders are returned for every dialect,
    /// although [`Dialect::ph`] uses `$<n>` for Postgres. Confirm before relying
    /// on this with Postgres.
    #[expect(dead_code)]
    pub fn limit_offset_clause(&self) -> &'static str {
        " LIMIT ? OFFSET ?"
    }

    /// Sub-directory of `migrations/` containing this dialect's files.
    pub fn migration_dir(&self) -> &'static str {
        self.cfg().migration_dir
    }

    /// Suffix of this dialect's schema file (e.g. `.sqlite.sql`).
    pub fn schema_ext(&self) -> &'static str {
        self.cfg().schema_ext
    }

    /// Builds the upsert tail of an `INSERT` statement.
    ///
    /// SQLite and Postgres get `ON CONFLICT(<conflict_cols>) DO UPDATE SET
    /// <assignments>`. MySQL gets `ON DUPLICATE KEY UPDATE <assignments>` and
    /// ignores `conflict_cols`.
    pub fn upsert_clause(&self, conflict_cols: &str, assignments: &str) -> String {
        match self {
            Dialect::Mysql => format!("ON DUPLICATE KEY UPDATE {assignments}"),
            Dialect::Sqlite | Dialect::Postgres => {
                format!("ON CONFLICT({conflict_cols}) DO UPDATE SET {assignments}")
            }
        }
    }

    /// Returns the dialect's reference to the proposed (incoming) value of `col`
    /// inside an upsert's assignments.
    ///
    /// SQLite and Postgres yield `excluded.<col>`; MySQL yields `VALUES(<col>)`.
    /// NOTE(review): MySQL documents `VALUES()` in this position as deprecated
    /// since 8.0.20; it still works but may need a row alias later.
    pub fn excluded_col(&self, col: &str) -> String {
        match self {
            Dialect::Mysql => format!("VALUES({col})"),
            Dialect::Sqlite | Dialect::Postgres => format!("excluded.{col}"),
        }
    }
}
/// Table layouts parsed from the active dialect's schema file.
#[derive(Debug, Clone)]
pub struct Schema {
    /// Tables keyed by their lower-cased name.
    pub tables: HashMap<String, TableSchema>,
    /// Dialect whose schema file was loaded.
    pub dialect: Dialect,
}

/// Columns of one table, in declaration order.
#[derive(Debug, Clone, Default)]
pub struct TableSchema {
    pub columns: Vec<ColumnSchema>,
}

/// One column of a [`TableSchema`].
#[derive(Debug, Clone)]
pub struct ColumnSchema {
    /// Lower-cased column name.
    pub name: String,
    /// Classification of the column's declared type.
    #[expect(dead_code)]
    pub ty: SqlType,
    /// `false` when the column definition contains `NOT NULL`.
    #[expect(dead_code)]
    pub nullable: bool,
    /// `true` when the column definition contains `DEFAULT`.
    #[expect(dead_code)]
    pub has_default: bool,
}

/// Coarse classification of a column's declared SQL type, as assigned by
/// `parse_column_line`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SqlType {
    Integer,
    Real,
    Text,
    Blob,
}
impl Schema {
    /// Loads the schema for the dialect selected by `DATABASE_URL`.
    ///
    /// Reads `$CARGO_MANIFEST_DIR/migrations/<migration_dir>/schema<schema_ext>`
    /// (e.g. `migrations/sqlite/schema.sqlite.sql`) and parses its
    /// `CREATE TABLE` statements. A missing or unreadable file is treated as
    /// empty, which yields a schema with no tables.
    ///
    /// # Panics
    ///
    /// Panics if `CARGO_MANIFEST_DIR` is not set.
    pub fn load() -> Self {
        let dialect = Dialect::from_env();
        let manifest_dir =
            std::env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR not set");
        // Build the path component by component rather than by string concatenation.
        let schema_path = PathBuf::from(manifest_dir)
            .join("migrations")
            .join(dialect.migration_dir())
            .join(format!("schema{}", dialect.schema_ext()));
        let schema_sql = std::fs::read_to_string(&schema_path).unwrap_or_default();
        Schema {
            tables: parse_schema(&schema_sql),
            dialect,
        }
    }

    /// Builds a comma-separated select list for `table`, or `None` if the table
    /// is unknown.
    ///
    /// Timestamp columns (see `is_timestamp_col`) are emitted as
    /// `col as "col: chrono::DateTime<chrono::Utc>"`. That is a type-override
    /// alias, presumably for sqlx's query macros (confirm). Other columns are
    /// emitted bare.
    #[expect(dead_code)]
    pub fn select_columns(&self, table: &str) -> Option<String> {
        let ts = self.tables.get(table)?;
        let parts: Vec<String> = ts
            .columns
            .iter()
            .map(|col| {
                if is_timestamp_col(&col.name) {
                    format!(r#"{0} as "{0}: chrono::DateTime<chrono::Utc>""#, col.name)
                } else {
                    col.name.clone()
                }
            })
            .collect();
        Some(parts.join(", "))
    }

    /// Returns the column names of `table` in declaration order.
    ///
    /// Returns an empty vector when the table is unknown. The lookup uses the
    /// exact key; table keys are lower-cased when the schema is parsed.
    pub fn column_names(&self, table: &str) -> Vec<String> {
        self.tables
            .get(table)
            .map(|ts| ts.columns.iter().map(|c| c.name.clone()).collect())
            .unwrap_or_default()
    }
}
/// Reports whether `name` is one of the column names rendered as UTC
/// timestamps by `Schema::select_columns`. The match is exact and case-sensitive.
// Its only caller in this file, `select_columns`, is itself marked
// `#[expect(dead_code)]`, so this helper is unused too for now.
#[allow(dead_code)]
fn is_timestamp_col(name: &str) -> bool {
    matches!(name, "created_at" | "updated_at" | "expires_at")
}
fn parse_schema(sql: &str) -> HashMap<String, TableSchema> {
let mut tables = HashMap::new();
for line in sql.lines() {
let trimmed = line.trim();
if let Some(rest) = trimmed.strip_prefix("CREATE TABLE")
&& let Some(name) = extract_table_name(rest)
{
tables.insert(
name,
TableSchema {
columns: Vec::new(),
},
);
}
}
let mut current_table: Option<String> = None;
let mut in_create = false;
for line in sql.lines() {
let trimmed = line.trim();
if trimmed.is_empty() {
continue;
}
if let Some(rest) = trimmed.strip_prefix("CREATE TABLE") {
if let Some(name) = extract_table_name(rest) {
current_table = Some(name);
in_create = true;
if let Some(t) = tables.get_mut(current_table.as_ref().unwrap()) {
t.columns.clear();
}
}
continue;
}
if in_create && trimmed.starts_with(')') {
in_create = false;
current_table = None;
continue;
}
if let (Some(tn), true) = (¤t_table, in_create)
&& let Some(col) = parse_column_line(trimmed)
&& let Some(t) = tables.get_mut(tn)
{
t.columns.push(col);
}
}
tables
}
/// Extracts the table name from the text that follows `CREATE TABLE`.
///
/// Skips an optional `IF NOT EXISTS` in any letter case. Accepts names quoted
/// with `"…"`, `` `…` `` or `[…]`, and drops a schema qualifier
/// (`main.users` and `"public"."users"` both give `users`). The name is the
/// leading run of alphanumeric/`_` characters of what remains, lower-cased.
/// Returns `None` when that run is empty.
fn extract_table_name(rest: &str) -> Option<String> {
    const IF_NOT_EXISTS: &str = "IF NOT EXISTS";
    let rest = rest.trim_start();
    let rest = match rest.get(..IF_NOT_EXISTS.len()) {
        Some(head) if head.eq_ignore_ascii_case(IF_NOT_EXISTS) => &rest[IF_NOT_EXISTS.len()..],
        _ => rest,
    };
    // The (possibly qualified and quoted) name ends at whitespace or the body's `(`.
    let token = rest
        .trim_start()
        .split(|c: char| c.is_whitespace() || c == '(')
        .next()?;
    let unqualified = token.rsplit('.').next()?;
    let unquoted = unqualified.trim_start_matches(|c: char| matches!(c, '"' | '`' | '['));
    let name = unquoted
        .split(|c: char| !c.is_alphanumeric() && c != '_')
        .next()?;
    if name.is_empty() {
        return None;
    }
    Some(name.to_lowercase())
}
/// Parses one column definition from a `CREATE TABLE` body.
///
/// `line` is a single body item such as `name TEXT NOT NULL DEFAULT ''`; a
/// trailing comma is tolerated. Returns `None` for anything that is not a
/// column definition:
///
/// * empty text;
/// * text that does not start with a column name, such as `--` comments or a
///   lone `(`;
/// * table constraints introduced by `PRIMARY KEY`, `UNIQUE`, `CHECK`,
///   `FOREIGN KEY` or `CONSTRAINT`, matched case-insensitively, with or without
///   a space before `(`.
///
/// Column names may be quoted with `"…"`, `` `…` `` or `[…]`; the quotes are
/// removed and the name is lower-cased. The type is the first word after the
/// name, ignoring any `(…)` size suffix (`VARCHAR(255)` gives `VARCHAR`), mapped
/// onto [`SqlType`]. A column with no declared type (legal in SQLite) maps to
/// `Blob`, matching SQLite's affinity rule for untyped columns. Nullability and
/// default detection are substring checks for `NOT NULL` and `DEFAULT`.
fn parse_column_line(line: &str) -> Option<ColumnSchema> {
    let line = line.trim().trim_end_matches(',').trim_end();
    let (name, rest) = split_column_name(line)?;
    let rest = rest.trim().to_uppercase();
    // Classify by the type word only, so constraint text such as
    // `DEFAULT 'POINT'` cannot influence the result.
    let type_name = rest
        .split(|c: char| c.is_whitespace() || c == '(')
        .next()
        .unwrap_or("");
    let ty = match type_name {
        "INT" | "INTEGER" | "INT2" | "INT4" | "INT8" | "TINYINT" | "SMALLINT" | "MEDIUMINT"
        | "BIGINT" | "SMALLSERIAL" | "SERIAL" | "BIGSERIAL" => SqlType::Integer,
        "REAL" | "FLOAT" | "FLOAT4" | "FLOAT8" | "DOUBLE" => SqlType::Real,
        "" | "BLOB" | "TINYBLOB" | "MEDIUMBLOB" | "LONGBLOB" | "BYTEA" | "BINARY"
        | "VARBINARY" => SqlType::Blob,
        _ => SqlType::Text,
    };
    let nullable = !rest.contains("NOT NULL");
    let has_default = rest.contains("DEFAULT");
    Some(ColumnSchema {
        name: name.to_lowercase(),
        ty,
        nullable,
        has_default,
    })
}

/// Splits a body item into its unquoted column name and the remaining text.
///
/// Returns `None` when the item does not begin with a column name: empty
/// text, an unterminated or empty quoted name, a leading character that cannot
/// start an identifier, or an (unquoted) table-constraint keyword.
fn split_column_name(def: &str) -> Option<(&str, &str)> {
    let first = def.chars().next()?;
    let close = match first {
        '"' => Some('"'),
        '`' => Some('`'),
        '[' => Some(']'),
        _ => None,
    };
    if let Some(close) = close {
        // The opening quote is one ASCII byte, as is the closing one.
        let inner = &def[1..];
        let end = inner.find(close)?;
        let name = &inner[..end];
        return if name.is_empty() {
            None
        } else {
            Some((name, &inner[end + 1..]))
        };
    }
    if !(first.is_alphanumeric() || first == '_') {
        return None;
    }
    let (name, rest) = match def.find(char::is_whitespace) {
        Some(i) => (&def[..i], &def[i..]),
        None => (def, ""),
    };
    // `UNIQUE(a)` has no space before the paren, so compare only the word before it.
    let keyword = name.split('(').next().unwrap_or(name).to_ascii_uppercase();
    if matches!(
        keyword.as_str(),
        "PRIMARY" | "UNIQUE" | "CHECK" | "FOREIGN" | "CONSTRAINT"
    ) {
        return None;
    }
    Some((name, rest))
}