use regex::Regex;
use std::collections::HashMap;
use std::fs;
use std::path::Path;
use std::sync::LazyLock;
/// A column as reconstructed from SQL DDL (`CREATE TABLE` bodies and `ALTER TABLE ... ADD`).
#[derive(Debug, Clone, PartialEq)]
pub struct SchemaColumn {
/// Column identifier as written in the SQL (case preserved).
pub name: String,
/// Declared type as parsed from the DDL, upper-cased (e.g. `TEXT`, `INTEGER`).
pub col_type: String,
/// Whether the column accepts NULL: `false` when a `NOT NULL` constraint was found.
pub nullable: bool,
/// `true` when the column definition carries a `DEFAULT` clause.
pub has_default: bool,
/// `true` when the column definition itself says `PRIMARY KEY`; table-level
/// `PRIMARY KEY (...)` clauses are skipped by the parser and not attributed to columns.
pub is_primary_key: bool,
}
/// Reconstructed state of one table after replaying DDL statements.
#[derive(Debug, Clone, Default)]
pub struct SchemaTable {
    /// Columns in declaration order; `ALTER TABLE ... ADD` appends at the end.
    pub columns: Vec<SchemaColumn>,
}

impl SchemaTable {
    /// Borrowed column names, in declaration order (test-only convenience).
    #[cfg(test)]
    pub fn column_names(&self) -> Vec<&str> {
        let mut names = Vec::with_capacity(self.columns.len());
        for column in &self.columns {
            names.push(column.name.as_str());
        }
        names
    }
}
/// A column as declared in a markdown spec's `### Schema` table.
#[derive(Debug, Clone, PartialEq)]
pub struct SpecColumn {
/// Name taken from the backticked first cell of a table row.
pub name: String,
/// Second cell of the row, trimmed and upper-cased.
pub col_type: String,
}
// DDL patterns. All are case-insensitive (`(?i)`) and capture identifiers as `\w+`, so quoted
// identifiers such as `"name"` are not recognised.

/// `CREATE [VIRTUAL] TABLE [IF NOT EXISTS] name (` — group 1 is the table name. The match ends
/// just past the opening parenthesis, where the column-definition body begins. A virtual table
/// written as `name USING module(...)` does not match because `(` must follow the name.
static CREATE_TABLE_RE: LazyLock<Regex> = LazyLock::new(|| {
Regex::new(r"(?i)CREATE\s+(?:VIRTUAL\s+)?TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?(\w+)\s*\(").unwrap()
});
/// `ALTER TABLE t ADD [COLUMN] col type` — group 1 table, group 2 column name, group 3 the
/// first word of the declared type. Constraints after the type are not captured.
/// NOTE(review): with `COLUMN` omitted, `ADD CONSTRAINT x ...` also matches, with
/// `CONSTRAINT` captured as the column name.
static ALTER_ADD_RE: LazyLock<Regex> = LazyLock::new(|| {
Regex::new(r"(?i)ALTER\s+TABLE\s+(\w+)\s+ADD\s+(?:COLUMN\s+)?(\w+)\s+(\w+)").unwrap()
});
/// `DROP TABLE [IF EXISTS] name` — group 1 is the table name.
static DROP_TABLE_RE: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r"(?i)DROP\s+TABLE\s+(?:IF\s+EXISTS\s+)?(\w+)").unwrap());
/// `ALTER TABLE t DROP [COLUMN] col` — group 1 table, group 2 column name.
static ALTER_DROP_COL_RE: LazyLock<Regex> = LazyLock::new(|| {
Regex::new(r"(?i)ALTER\s+TABLE\s+(\w+)\s+DROP\s+(?:COLUMN\s+)?(\w+)").unwrap()
});
/// `ALTER TABLE old RENAME TO new` — group 1 old table name, group 2 new table name.
static ALTER_RENAME_TABLE_RE: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r"(?i)ALTER\s+TABLE\s+(\w+)\s+RENAME\s+TO\s+(\w+)").unwrap());
/// `ALTER TABLE t RENAME [COLUMN] old TO new` — group 1 table, group 2 old column name,
/// group 3 new column name. A table rename (`RENAME TO x`) does not match this pattern.
static ALTER_RENAME_COL_RE: LazyLock<Regex> = LazyLock::new(|| {
Regex::new(r"(?i)ALTER\s+TABLE\s+(\w+)\s+RENAME\s+(?:COLUMN\s+)?(\w+)\s+TO\s+(\w+)").unwrap()
});
/// File extensions that `build_schema` reads when looking for DDL statements. Besides `sql`,
/// common application-source extensions are listed. Presumably this is because schemas and
/// migrations are often embedded in code as string literals (NOTE(review): confirm intent).
const SQL_EXTENSIONS: &[&str] = &[
"sql", "ts", "js", "mjs", "cjs", "swift", "kt", "kts", "java", "py", "rb", "go", "rs", "cs",
"dart", "php",
];
/// Builds the final schema by replaying every SQL-bearing file found directly in `schema_dir`.
///
/// A file is selected when its extension is in [`SQL_EXTENSIONS`], compared
/// ASCII-case-insensitively, so `001_INIT.SQL` is picked up too. Files are applied in natural
/// file-name order: runs of digits compare by numeric value, so `2_add.sql` is replayed before
/// `10_drop.sql`. Zero-padded or timestamped names keep their usual order. Each file's contents
/// are handed to `parse_sql_into`.
///
/// This is best-effort:
/// - a missing or unreadable directory yields an empty map;
/// - files that cannot be read as UTF-8 text are skipped;
/// - subdirectories are not searched.
pub fn build_schema(schema_dir: &Path) -> HashMap<String, SchemaTable> {
    let mut tables = HashMap::new();
    let Ok(entries) = fs::read_dir(schema_dir) else {
        return tables;
    };
    let mut files: Vec<(String, fs::DirEntry)> = entries
        .flatten()
        .filter(|entry| {
            entry
                .path()
                .extension()
                .and_then(|ext| ext.to_str())
                .is_some_and(|ext| {
                    SQL_EXTENSIONS
                        .iter()
                        .any(|known| ext.eq_ignore_ascii_case(known))
                })
        })
        .map(|entry| (entry.file_name().to_string_lossy().into_owned(), entry))
        .collect();
    files.sort_by(|(a, _), (b, _)| natural_cmp(a, b));
    for (_, entry) in &files {
        // Best-effort: an unreadable or non-UTF-8 file contributes nothing.
        if let Ok(content) = fs::read_to_string(entry.path()) {
            parse_sql_into(&content, &mut tables);
        }
    }
    tables
}

/// Compares two names so that maximal runs of ASCII digits are ordered by numeric value, and
/// every other character by code point. Names that are equal under that rule (e.g. `01` vs `1`)
/// fall back to plain string order, so the ordering is total and deterministic.
fn natural_cmp(a: &str, b: &str) -> std::cmp::Ordering {
    use std::cmp::Ordering;

    let (mut x, mut y) = (a, b);
    loop {
        let (Some(cx), Some(cy)) = (x.chars().next(), y.chars().next()) else {
            // At least one side is exhausted: the shorter remainder sorts first.
            return x.len().cmp(&y.len()).then_with(|| a.cmp(b));
        };
        if cx.is_ascii_digit() && cy.is_ascii_digit() {
            let x_end = x.find(|c: char| !c.is_ascii_digit()).unwrap_or(x.len());
            let y_end = y.find(|c: char| !c.is_ascii_digit()).unwrap_or(y.len());
            // Compare digit runs without parsing (no overflow): drop leading zeros, then the
            // longer run is the larger number, and equal-length runs compare lexically.
            let xn = x[..x_end].trim_start_matches('0');
            let yn = y[..y_end].trim_start_matches('0');
            let ord = xn.len().cmp(&yn.len()).then_with(|| xn.cmp(yn));
            if ord != Ordering::Equal {
                return ord;
            }
            x = &x[x_end..];
            y = &y[y_end..];
        } else {
            if cx != cy {
                return cx.cmp(&cy);
            }
            x = &x[cx.len_utf8()..];
            y = &y[cy.len_utf8()..];
        }
    }
}
fn parse_sql_into(sql: &str, tables: &mut HashMap<String, SchemaTable>) {
for cap in CREATE_TABLE_RE.captures_iter(sql) {
let table_name = cap[1].to_string();
let start = cap.get(0).unwrap().end();
if let Some(body) = extract_paren_body(sql, start) {
let columns = parse_column_defs(&body);
let entry = tables.entry(table_name).or_default();
entry.columns = columns;
}
}
for cap in ALTER_ADD_RE.captures_iter(sql) {
let table_name = cap[1].to_string();
let col_name = cap[2].to_string();
let col_type = cap[3].to_uppercase();
let full_match_start = cap.get(0).unwrap().start();
let rest = &sql[full_match_start..];
let stmt_end = rest.find(';').unwrap_or(rest.len());
let full_stmt = &rest[..stmt_end].to_uppercase();
let nullable = !full_stmt.contains("NOT NULL");
let has_default = full_stmt.contains("DEFAULT");
let is_primary_key = full_stmt.contains("PRIMARY KEY");
let entry = tables.entry(table_name).or_default();
if !entry.columns.iter().any(|c| c.name == col_name) {
entry.columns.push(SchemaColumn {
name: col_name,
col_type,
nullable,
has_default,
is_primary_key,
});
}
}
for cap in DROP_TABLE_RE.captures_iter(sql) {
let table_name = cap[1].to_string();
tables.remove(&table_name);
}
for cap in ALTER_DROP_COL_RE.captures_iter(sql) {
let table_name = cap[1].to_string();
let col_name = cap[2].to_string();
if let Some(table) = tables.get_mut(&table_name) {
table.columns.retain(|c| c.name != col_name);
}
}
for cap in ALTER_RENAME_TABLE_RE.captures_iter(sql) {
let old_name = cap[1].to_string();
let new_name = cap[2].to_string();
if let Some(table) = tables.remove(&old_name) {
tables.insert(new_name, table);
}
}
for cap in ALTER_RENAME_COL_RE.captures_iter(sql) {
let table_name = cap[1].to_string();
let old_col = cap[2].to_string();
let new_col = cap[3].to_string();
if let Some(table) = tables.get_mut(&table_name)
&& let Some(col) = table.columns.iter_mut().find(|c| c.name == old_col)
{
col.name = new_col;
}
}
}
/// Returns the text between an opening parenthesis and its matching `)`, with SQL comments
/// removed.
///
/// `start` is the byte offset just past the opening `(`, e.g. the end of a
/// `CREATE TABLE name (` match; it must lie on a char boundary.
///
/// Nesting is tracked, so `DECIMAL(10, 2)` or `CHECK (x IN (1, 2))` stay inside the body.
/// Parentheses are ignored when they appear:
/// - inside `'...'` string literals, where a doubled `''` is an escaped quote;
/// - inside comments.
///
/// Comments are stripped from the returned body:
/// - a `-- ...` comment is dropped up to, but not including, its newline;
/// - a `/* ... */` comment is replaced by a single space.
///
/// Left in place, commas or apostrophes in commentary would split or swallow column
/// definitions when the body is later split on commas.
///
/// Returns `None` when the matching `)` is never reached, including when a string literal or
/// block comment is left unterminated.
fn extract_paren_body(sql: &str, start: usize) -> Option<String> {
    let bytes = sql.as_bytes();
    let len = bytes.len();
    let mut body = String::new();
    let mut depth = 1usize;
    // Start of the not-yet-copied run of body text; comments are cut out between runs.
    let mut seg_start = start;
    let mut i = start;
    while i < len {
        match bytes[i] {
            b'(' => depth += 1,
            b')' => {
                depth -= 1;
                if depth == 0 {
                    body.push_str(&sql[seg_start..i]);
                    return Some(body);
                }
            }
            b'\'' => {
                // Skip to the closing quote; a doubled quote is an escaped literal quote.
                i += 1;
                while i < len {
                    if bytes[i] == b'\'' {
                        if bytes.get(i + 1) == Some(&b'\'') {
                            i += 2;
                            continue;
                        }
                        break;
                    }
                    i += 1;
                }
            }
            b'-' if bytes.get(i + 1) == Some(&b'-') => {
                body.push_str(&sql[seg_start..i]);
                // Resume at the newline (kept, so line structure survives) or end of input.
                i = sql[i..].find('\n').map_or(len, |nl| i + nl);
                seg_start = i;
                continue;
            }
            b'/' if bytes.get(i + 1) == Some(&b'*') => {
                body.push_str(&sql[seg_start..i]);
                // A space keeps the tokens on either side of the comment apart.
                body.push(' ');
                i = i + 2 + sql[i + 2..].find("*/")? + 2;
                seg_start = i;
                continue;
            }
            _ => {}
        }
        i += 1;
    }
    None
}
/// Parses the comma-separated body of a `CREATE TABLE (...)` into column definitions.
///
/// The following entries are skipped:
/// - entries with fewer than two whitespace-separated tokens;
/// - table-level constraint clauses (`PRIMARY KEY (...)`, `UNIQUE (...)`, `CHECK (...)`,
///   `FOREIGN KEY (...)`, `CONSTRAINT ...`);
/// - entries whose first token is a keyword recognised by `is_sql_keyword`.
///
/// Constraint prefixes are matched as whole words, so columns such as `checksum` or
/// `unique_code` are kept rather than mistaken for `CHECK` / `UNIQUE` clauses.
///
/// For each column:
/// - `name` is the first token, as written;
/// - `col_type` is the second token, upper-cased. While it has an unclosed `(` it is extended
///   over the following tokens, so `DECIMAL(10, 2)` is kept whole;
/// - `nullable`, `has_default` and `is_primary_key` look for `NOT NULL`, `DEFAULT` and
///   `PRIMARY KEY` as whole words in the text after the name. A column named `default_lang`
///   or a clause like `REFERENCES defaults(id)` is therefore not read as a default.
fn parse_column_defs(body: &str) -> Vec<SchemaColumn> {
    const TABLE_CONSTRAINTS: [&str; 5] =
        ["PRIMARY KEY", "UNIQUE", "CHECK", "FOREIGN KEY", "CONSTRAINT"];

    let mut columns = Vec::new();
    for part in split_top_level(body, ',') {
        // Tokenising also normalises whitespace, so `NOT\n  NULL` reads as `NOT NULL`.
        let tokens: Vec<&str> = part.split_whitespace().collect();
        if tokens.len() < 2 {
            continue;
        }
        let upper = tokens.join(" ").to_uppercase();
        if TABLE_CONSTRAINTS
            .iter()
            .any(|kw| starts_with_keyword(&upper, kw))
        {
            continue;
        }
        let name = tokens[0];
        if is_sql_keyword(name) {
            continue;
        }

        // A parameterised type such as `DECIMAL(10, 2)` was split by whitespace; glue it back.
        let mut type_end = 2;
        let mut col_type = tokens[1].to_string();
        while type_end < tokens.len()
            && col_type.matches('(').count() > col_type.matches(')').count()
        {
            col_type.push(' ');
            col_type.push_str(tokens[type_end]);
            type_end += 1;
        }

        // Everything after the name; the type is included so typeless definitions such as
        // `id PRIMARY KEY` still register their constraint.
        let constraints = tokens[1..].join(" ").to_uppercase();
        columns.push(SchemaColumn {
            name: name.to_string(),
            col_type: col_type.to_uppercase(),
            nullable: !contains_keyword(&constraints, "NOT NULL"),
            has_default: contains_keyword(&constraints, "DEFAULT"),
            is_primary_key: contains_keyword(&constraints, "PRIMARY KEY"),
        });
    }
    columns
}

/// True when `c` can be part of an SQL identifier, so a keyword touching it is not a whole word.
fn is_ident_char(c: char) -> bool {
    c.is_alphanumeric() || c == '_'
}

/// True when `upper` begins with the keyword phrase `kw` as a whole word: `CHECK (` and
/// `CHECK(` qualify, `CHECKSUM` does not.
fn starts_with_keyword(upper: &str, kw: &str) -> bool {
    upper
        .strip_prefix(kw)
        .is_some_and(|rest| !rest.chars().next().is_some_and(is_ident_char))
}

/// True when `upper` contains the keyword phrase `kw` with no identifier character directly
/// before or after it.
fn contains_keyword(upper: &str, kw: &str) -> bool {
    upper.match_indices(kw).any(|(pos, _)| {
        let glued_before = upper[..pos].chars().next_back().is_some_and(is_ident_char);
        let glued_after = upper[pos + kw.len()..]
            .chars()
            .next()
            .is_some_and(is_ident_char);
        !glued_before && !glued_after
    })
}
/// Splits `s` on `delim`, ignoring delimiters nested inside parentheses or inside `'...'`
/// string literals (a doubled `''` is an escaped quote).
///
/// The pieces are returned untrimmed, and empty pieces between adjacent delimiters are kept. A
/// trailing piece is dropped only when it is blank (whitespace only).
///
/// Iterates over `char`s rather than bytes, so multi-byte UTF-8 text (non-ASCII column names,
/// defaults, comments) is copied through unchanged.
fn split_top_level(s: &str, delim: char) -> Vec<String> {
    let mut parts = Vec::new();
    let mut current = String::new();
    // Signed on purpose: a stray `)` drives it negative, exactly as before.
    let mut depth: i32 = 0;
    let mut chars = s.chars().peekable();
    while let Some(ch) = chars.next() {
        match ch {
            '(' => {
                depth += 1;
                current.push(ch);
            }
            ')' => {
                depth -= 1;
                current.push(ch);
            }
            '\'' => {
                // Copy the whole literal verbatim so delimiters inside it are not split on.
                current.push(ch);
                while let Some(c) = chars.next() {
                    current.push(c);
                    if c == '\'' {
                        if chars.peek() == Some(&'\'') {
                            current.push('\'');
                            chars.next();
                            continue;
                        }
                        break;
                    }
                }
            }
            c if c == delim && depth == 0 => {
                parts.push(std::mem::take(&mut current));
            }
            _ => current.push(ch),
        }
    }
    if !current.trim().is_empty() {
        parts.push(current);
    }
    parts
}
/// Reports whether `s`, after Unicode upper-casing, is one of a fixed set of SQL keywords
/// that must not be taken as a column name.
fn is_sql_keyword(s: &str) -> bool {
    const RESERVED: [&str; 8] = [
        "PRIMARY",
        "UNIQUE",
        "CHECK",
        "FOREIGN",
        "CONSTRAINT",
        "INDEX",
        "CREATE",
        "TABLE",
    ];
    let upper = s.to_uppercase();
    RESERVED.contains(&upper.as_str())
}
/// A `### Schema` heading line, optionally `### Schema: table_name`; group 1 is the inline
/// table name when present.
static SCHEMA_HEADER_RE: LazyLock<Regex> = LazyLock::new(|| {
Regex::new(r"(?m)^###\s+Schema(?::\s*(\w+))?\s*$").unwrap()
});
/// A `#### table_name` sub-heading (the name may be wrapped in backticks); group 1 is the table
/// name. Applied one line at a time.
static SCHEMA_TABLE_HEADER_RE: LazyLock<Regex> = LazyLock::new(|| {
Regex::new(r"(?m)^####\s+`?(\w+)`?\s*$").unwrap()
});
/// A markdown table row whose first cell is a backticked identifier (group 1), followed by a
/// second cell (group 2, matched lazily up to the next `|`). Header rows such as
/// `| Column | Type |` do not match because their first cell has no backticks.
static COLUMN_ROW_RE: LazyLock<Regex> = LazyLock::new(|| {
Regex::new(r"^\|\s*`(\w+)`\s*\|\s*([^|]+?)\s*\|").unwrap()
});
/// Extracts the column lists declared in a spec's markdown `### Schema` sections.
///
/// Two layouts are recognised:
/// - `### Schema: table`: every matching row in the section belongs to `table`.
/// - `### Schema` followed by `#### table` sub-headings: rows belong to the most recent
///   sub-heading. Rows before the first sub-heading have no table and are ignored.
///
/// A section runs from the line after its heading up to, but not including, the next line that
/// starts with `## ` or `### `, or to the end of `body`.
///
/// Rows are lines matching `COLUMN_ROW_RE` (first cell is a backticked name). The second cell,
/// trimmed and upper-cased, becomes the type. Tables without any rows are omitted, and a later
/// section for the same table replaces an earlier one.
pub fn parse_spec_schema(body: &str) -> HashMap<String, Vec<SpecColumn>> {
    let mut result = HashMap::new();
    for schema_cap in SCHEMA_HEADER_RE.captures_iter(body) {
        let header_start = schema_cap
            .get(0)
            .expect("group 0 is always present")
            .start();
        // A heading on the final line (no trailing newline) has no section to read.
        let Some(nl) = body[header_start..].find('\n') else {
            continue;
        };
        let section_start = header_start + nl + 1;
        let section_end = section_start + schema_section_len(&body[section_start..]);
        let section = &body[section_start..section_end];
        match schema_cap.get(1) {
            Some(table) => {
                let columns = extract_columns_from_section(section);
                if !columns.is_empty() {
                    result.insert(table.as_str().to_string(), columns);
                }
            }
            None => collect_sub_tables(section, &mut result),
        }
    }
    result
}

/// Byte length of the schema section at the start of `rest`: everything before the first line
/// that opens a `##` or `###` heading. Every line is checked, including the very first, so an
/// empty section directly followed by another heading stays empty. `####` sub-headings do not
/// end the section.
fn schema_section_len(rest: &str) -> usize {
    let mut offset = 0;
    for line in rest.split_inclusive('\n') {
        if line.starts_with("## ") || line.starts_with("### ") {
            return offset;
        }
        offset += line.len();
    }
    rest.len()
}

/// Collects the `#### table` sub-sections of a multi-table schema section into `result`.
/// Rows before the first sub-heading have no table to belong to and are skipped; tables with
/// no rows are not inserted.
fn collect_sub_tables(section: &str, result: &mut HashMap<String, Vec<SpecColumn>>) {
    let mut current: Option<(String, Vec<SpecColumn>)> = None;
    for line in section.lines() {
        if let Some(cap) = SCHEMA_TABLE_HEADER_RE.captures(line) {
            if let Some((table, columns)) = current.take()
                && !columns.is_empty()
            {
                result.insert(table, columns);
            }
            current = Some((cap[1].to_string(), Vec::new()));
        } else if let Some((_, columns)) = current.as_mut()
            && let Some(cap) = COLUMN_ROW_RE.captures(line)
        {
            let name = cap[1].to_string();
            if !is_table_header_word(&name) {
                columns.push(SpecColumn {
                    name,
                    col_type: cap[2].trim().to_uppercase(),
                });
            }
        }
    }
    if let Some((table, columns)) = current
        && !columns.is_empty()
    {
        result.insert(table, columns);
    }
}
/// Collects every markdown table row in `section` whose first cell is a backticked name,
/// skipping a literal `column` cell. The second cell, trimmed and upper-cased, becomes the
/// column type. Rows keep their order of appearance.
fn extract_columns_from_section(section: &str) -> Vec<SpecColumn> {
    section
        .lines()
        .filter_map(|line| COLUMN_ROW_RE.captures(line))
        .filter(|caps| !is_table_header_word(&caps[1]))
        .map(|caps| SpecColumn {
            name: caps[1].to_owned(),
            col_type: caps[2].trim().to_uppercase(),
        })
        .collect()
}
/// True when a first-cell name is the word `column` in any ASCII casing, i.e. a header cell
/// rather than a real column.
fn is_table_header_word(name: &str) -> bool {
    name.to_ascii_lowercase() == "column"
}
// Unit tests for DDL replay (`parse_sql_into`, `build_schema`) and for markdown spec-schema
// parsing (`parse_spec_schema`).
#[cfg(test)]
mod tests {
use super::*;
// Column order, types and NOT NULL / DEFAULT / PRIMARY KEY flags from a multi-line CREATE TABLE.
#[test]
fn test_parse_create_table() {
let sql = r#"
CREATE TABLE messages (
id INTEGER PRIMARY KEY AUTOINCREMENT,
content TEXT NOT NULL,
sender TEXT NOT NULL,
created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
read INTEGER DEFAULT 0
);
"#;
let mut tables = HashMap::new();
parse_sql_into(sql, &mut tables);
let t = tables.get("messages").unwrap();
assert_eq!(t.columns.len(), 5);
assert_eq!(t.columns[0].name, "id");
assert_eq!(t.columns[0].col_type, "INTEGER");
assert!(t.columns[0].is_primary_key);
assert_eq!(t.columns[1].name, "content");
assert_eq!(t.columns[1].col_type, "TEXT");
assert!(!t.columns[1].nullable);
assert_eq!(t.columns[3].name, "created_at");
assert!(t.columns[3].has_default);
assert_eq!(t.columns[4].name, "read");
assert!(t.columns[4].nullable);
assert!(t.columns[4].has_default);
}
// `IF NOT EXISTS` is accepted and the table name is still captured.
#[test]
fn test_parse_create_table_if_not_exists() {
let sql = "CREATE TABLE IF NOT EXISTS users (id INTEGER PRIMARY KEY, name TEXT NOT NULL);";
let mut tables = HashMap::new();
parse_sql_into(sql, &mut tables);
let t = tables.get("users").unwrap();
assert_eq!(t.columns.len(), 2);
assert_eq!(t.columns[0].name, "id");
assert_eq!(t.columns[1].name, "name");
}
// A `USING module(...)` virtual table is not recorded: the CREATE regex needs `(` right after
// the table name.
#[test]
fn test_parse_create_virtual_table() {
let sql = "CREATE VIRTUAL TABLE search_idx USING fts5(content, sender);";
let mut tables = HashMap::new();
parse_sql_into(sql, &mut tables);
assert!(!tables.contains_key("search_idx"));
}
// ALTER TABLE ... ADD COLUMN appends columns with their type and constraint flags.
#[test]
fn test_parse_alter_table_add_column() {
let sql = r#"
CREATE TABLE tasks (id INTEGER PRIMARY KEY, title TEXT NOT NULL);
ALTER TABLE tasks ADD COLUMN status TEXT NOT NULL DEFAULT 'pending';
ALTER TABLE tasks ADD COLUMN priority INTEGER DEFAULT 0;
"#;
let mut tables = HashMap::new();
parse_sql_into(sql, &mut tables);
let t = tables.get("tasks").unwrap();
assert_eq!(t.columns.len(), 4);
assert_eq!(t.columns[2].name, "status");
assert_eq!(t.columns[2].col_type, "TEXT");
assert!(!t.columns[2].nullable);
assert!(t.columns[2].has_default);
assert_eq!(t.columns[3].name, "priority");
assert!(t.columns[3].nullable);
}
// Adding a column that already exists does not duplicate it.
#[test]
fn test_alter_idempotent() {
let sql = r#"
CREATE TABLE t (id INTEGER PRIMARY KEY);
ALTER TABLE t ADD COLUMN name TEXT;
ALTER TABLE t ADD COLUMN name TEXT;
"#;
let mut tables = HashMap::new();
parse_sql_into(sql, &mut tables);
assert_eq!(tables.get("t").unwrap().columns.len(), 2);
}
// Table-level PRIMARY KEY / FOREIGN KEY / UNIQUE / CHECK clauses are not treated as columns.
#[test]
fn test_table_constraints_skipped() {
let sql = r#"
CREATE TABLE edges (
source_id INTEGER NOT NULL,
target_id INTEGER NOT NULL,
weight REAL DEFAULT 1.0,
PRIMARY KEY (source_id, target_id),
FOREIGN KEY (source_id) REFERENCES nodes(id),
UNIQUE (source_id, target_id, weight),
CHECK (weight > 0)
);
"#;
let mut tables = HashMap::new();
parse_sql_into(sql, &mut tables);
let t = tables.get("edges").unwrap();
assert_eq!(t.columns.len(), 3);
assert_eq!(t.column_names(), vec!["source_id", "target_id", "weight"]);
}
// A doubled quote inside a DEFAULT string literal does not break body extraction or splitting.
#[test]
fn test_string_literal_in_default() {
let sql = "CREATE TABLE t (status TEXT NOT NULL DEFAULT 'it''s pending');";
let mut tables = HashMap::new();
parse_sql_into(sql, &mut tables);
let t = tables.get("t").unwrap();
assert_eq!(t.columns.len(), 1);
assert!(t.columns[0].has_default);
}
// `### Schema: table` form: rows go to the inline table name; the section stops at `## `.
#[test]
fn test_parse_spec_schema_inline() {
let body = r#"## Purpose
Something
### Schema: messages
| Column | Type | Constraints |
|--------|------|-------------|
| `id` | INTEGER | PRIMARY KEY |
| `content` | TEXT | NOT NULL |
| `created_at` | TEXT | DEFAULT |
## Invariants
"#;
let schema = parse_spec_schema(body);
assert_eq!(schema.len(), 1);
let cols = schema.get("messages").unwrap();
assert_eq!(cols.len(), 3);
assert_eq!(cols[0].name, "id");
assert_eq!(cols[0].col_type, "INTEGER");
assert_eq!(cols[1].name, "content");
assert_eq!(cols[2].name, "created_at");
}
// `### Schema` with `#### table` sub-headings yields one entry per table.
#[test]
fn test_parse_spec_schema_multi_table() {
let body = r#"## Purpose
Something
### Schema
#### `messages`
| Column | Type | Description |
|--------|------|-------------|
| `id` | INTEGER | Row ID |
| `body` | TEXT | Message body |
#### `users`
| Column | Type | Description |
|--------|------|-------------|
| `id` | INTEGER | Row ID |
| `name` | TEXT | Username |
| `email` | TEXT | Email addr |
## Invariants
"#;
let schema = parse_spec_schema(body);
assert_eq!(schema.len(), 2);
assert_eq!(schema.get("messages").unwrap().len(), 2);
assert_eq!(schema.get("users").unwrap().len(), 3);
}
// A spec without a schema heading produces no tables.
#[test]
fn test_parse_spec_schema_no_section() {
let body = "## Purpose\nSomething\n## Public API\nStuff\n";
let schema = parse_spec_schema(body);
assert!(schema.is_empty());
}
// A missing directory yields an empty schema rather than an error.
#[test]
fn test_build_schema_nonexistent_dir() {
let tables = build_schema(Path::new("/nonexistent/path"));
assert!(tables.is_empty());
}
// Files are replayed in file-name order, so the later ALTER applies to the earlier CREATE.
// Uses the `tempfile` crate for a scratch directory.
#[test]
fn test_build_schema_migration_ordering() {
let tmp = tempfile::tempdir().unwrap();
let dir = tmp.path();
fs::write(
dir.join("001_create.sql"),
"CREATE TABLE items (id INTEGER PRIMARY KEY, name TEXT NOT NULL);",
)
.unwrap();
fs::write(
dir.join("002_add_col.sql"),
"ALTER TABLE items ADD COLUMN price REAL DEFAULT 0;",
)
.unwrap();
let tables = build_schema(dir);
let t = tables.get("items").unwrap();
assert_eq!(t.columns.len(), 3);
assert_eq!(t.columns[0].name, "id");
assert_eq!(t.columns[1].name, "name");
assert_eq!(t.columns[2].name, "price");
assert_eq!(t.columns[2].col_type, "REAL");
}
// DROP TABLE removes only the named table.
#[test]
fn test_drop_table() {
let sql = r#"
CREATE TABLE temp_data (id INTEGER PRIMARY KEY, value TEXT);
CREATE TABLE keep_me (id INTEGER PRIMARY KEY);
DROP TABLE temp_data;
"#;
let mut tables = HashMap::new();
parse_sql_into(sql, &mut tables);
assert!(!tables.contains_key("temp_data"));
assert!(tables.contains_key("keep_me"));
}
// `DROP TABLE IF EXISTS` is recognised as well.
#[test]
fn test_drop_table_if_exists() {
let sql = r#"
CREATE TABLE things (id INTEGER PRIMARY KEY, name TEXT);
DROP TABLE IF EXISTS things;
"#;
let mut tables = HashMap::new();
parse_sql_into(sql, &mut tables);
assert!(!tables.contains_key("things"));
}
// ALTER TABLE ... DROP COLUMN removes the column and keeps the others in order.
#[test]
fn test_drop_column() {
let sql = r#"
CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT, legacy TEXT);
ALTER TABLE users DROP COLUMN legacy;
"#;
let mut tables = HashMap::new();
parse_sql_into(sql, &mut tables);
let t = tables.get("users").unwrap();
assert_eq!(t.columns.len(), 2);
assert_eq!(t.column_names(), vec!["id", "name"]);
}
// ALTER TABLE ... RENAME TO moves the columns under the new table name.
#[test]
fn test_rename_table() {
let sql = r#"
CREATE TABLE old_name (id INTEGER PRIMARY KEY, data TEXT);
ALTER TABLE old_name RENAME TO new_name;
"#;
let mut tables = HashMap::new();
parse_sql_into(sql, &mut tables);
assert!(!tables.contains_key("old_name"));
let t = tables.get("new_name").unwrap();
assert_eq!(t.columns.len(), 2);
}
// RENAME COLUMN changes only the name; position and type are preserved.
#[test]
fn test_rename_column() {
let sql = r#"
CREATE TABLE items (id INTEGER PRIMARY KEY, old_col TEXT NOT NULL);
ALTER TABLE items RENAME COLUMN old_col TO new_col;
"#;
let mut tables = HashMap::new();
parse_sql_into(sql, &mut tables);
let t = tables.get("items").unwrap();
assert_eq!(t.columns.len(), 2);
assert_eq!(t.columns[1].name, "new_col");
assert_eq!(t.columns[1].col_type, "TEXT");
}
// The scanned extension list covers `.sql` plus the listed application-source extensions.
#[test]
fn test_sql_extensions_list() {
assert!(SQL_EXTENSIONS.contains(&"sql"));
assert!(SQL_EXTENSIONS.contains(&"ts"));
assert!(SQL_EXTENSIONS.contains(&"swift"));
assert!(SQL_EXTENSIONS.contains(&"kt"));
assert!(SQL_EXTENSIONS.contains(&"java"));
assert!(SQL_EXTENSIONS.contains(&"py"));
assert!(SQL_EXTENSIONS.contains(&"rb"));
assert!(SQL_EXTENSIONS.contains(&"go"));
assert!(SQL_EXTENSIONS.contains(&"rs"));
}
// Several CREATE TABLE statements in one script each produce a table.
#[test]
fn test_multiple_tables_in_one_file() {
let sql = r#"
CREATE TABLE a (id INTEGER PRIMARY KEY);
CREATE TABLE b (id INTEGER PRIMARY KEY, ref_a INTEGER);
"#;
let mut tables = HashMap::new();
parse_sql_into(sql, &mut tables);
assert!(tables.contains_key("a"));
assert!(tables.contains_key("b"));
assert_eq!(tables.get("b").unwrap().columns.len(), 2);
}
}