use crate::connection::Connection;
use crate::error::SqliteError;
/// One versioned schema change.
///
/// `Migrator::migrate` applies migrations in ascending `version` order and
/// records each applied `version`/`description` pair in `_migrations`.
#[derive(Clone, Debug)]
pub struct Migration {
    /// Version number; migrations are sorted by this before being applied.
    pub version: u32,
    /// Human-readable summary stored next to the version once applied.
    pub description: String,
    /// SQL text executed (as a batch) when the migration is applied.
    pub sql: String,
}
impl Migration {
    /// Builds a migration from its version number, description and SQL text.
    pub fn new(version: u32, description: impl Into<String>, sql: impl Into<String>) -> Self {
        let description: String = description.into();
        let sql: String = sql.into();
        Migration {
            version,
            description,
            sql,
        }
    }
}
/// Outcome of a single `Migrator::migrate` run.
#[derive(Default, Debug)]
pub struct MigrationReport {
    /// How many migrations were executed during this run.
    pub applied: usize,
    /// How many migrations were skipped because their version was already recorded.
    pub skipped: usize,
    /// Highest version recorded after the run (0 when nothing has been applied).
    pub current_version: u32,
}
/// Fluent builder for a single `CREATE TABLE IF NOT EXISTS` statement.
///
/// Table names, column names and raw definitions are interpolated into the SQL
/// verbatim (no identifier quoting or validation), so they must come from
/// trusted code. Only `text_default` values are escaped, because they are
/// emitted inside a SQL string literal.
pub struct TableBuilder {
    name: String,
    columns: Vec<String>,
}
impl TableBuilder {
    /// Starts an empty definition for table `name`.
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            columns: Vec::new(),
        }
    }

    /// Appends one column or table-constraint definition and returns the builder.
    fn push(mut self, definition: String) -> Self {
        self.columns.push(definition);
        self
    }

    /// Adds `id INTEGER PRIMARY KEY AUTOINCREMENT`.
    pub fn id(self) -> Self {
        self.push("id INTEGER PRIMARY KEY AUTOINCREMENT".to_string())
    }

    /// Adds a nullable `INTEGER` column.
    pub fn integer(self, name: &str) -> Self {
        self.push(format!("{} INTEGER", name))
    }

    /// Adds an `INTEGER NOT NULL` column.
    pub fn integer_not_null(self, name: &str) -> Self {
        self.push(format!("{} INTEGER NOT NULL", name))
    }

    /// Adds an `INTEGER` column with the literal default `default`.
    pub fn integer_default(self, name: &str, default: i64) -> Self {
        self.push(format!("{} INTEGER DEFAULT {}", name, default))
    }

    /// Adds a nullable `TEXT` column.
    pub fn text(self, name: &str) -> Self {
        self.push(format!("{} TEXT", name))
    }

    /// Adds a `TEXT NOT NULL` column.
    pub fn text_not_null(self, name: &str) -> Self {
        self.push(format!("{} TEXT NOT NULL", name))
    }

    /// Adds a `TEXT` column whose default is the string literal `default`.
    ///
    /// Embedded single quotes are doubled (SQL's escape for `'` inside a string
    /// literal). For example, `it's` becomes `DEFAULT 'it''s'` instead of
    /// malformed or injected SQL.
    pub fn text_default(self, name: &str, default: &str) -> Self {
        let escaped = default.replace('\'', "''");
        self.push(format!("{} TEXT DEFAULT '{}'", name, escaped))
    }

    /// Adds a nullable `REAL` column.
    pub fn real(self, name: &str) -> Self {
        self.push(format!("{} REAL", name))
    }

    /// Adds a `REAL NOT NULL` column.
    pub fn real_not_null(self, name: &str) -> Self {
        self.push(format!("{} REAL NOT NULL", name))
    }

    /// Adds a nullable `BLOB` column.
    pub fn blob(self, name: &str) -> Self {
        self.push(format!("{} BLOB", name))
    }

    /// Adds a boolean column, stored as `INTEGER` (SQLite has no native boolean type).
    pub fn boolean(self, name: &str) -> Self {
        self.push(format!("{} INTEGER", name))
    }

    /// Adds a boolean `INTEGER` column defaulting to `1` (true) or `0` (false).
    pub fn boolean_default(self, name: &str, default: bool) -> Self {
        self.push(format!("{} INTEGER DEFAULT {}", name, u8::from(default)))
    }

    /// Adds `created_at TEXT DEFAULT (datetime('now'))`.
    pub fn created_at(self) -> Self {
        self.push("created_at TEXT DEFAULT (datetime('now'))".to_string())
    }

    /// Adds `updated_at TEXT DEFAULT (datetime('now'))`.
    ///
    /// Only the insert-time default is set; nothing here refreshes it on update.
    pub fn updated_at(self) -> Self {
        self.push("updated_at TEXT DEFAULT (datetime('now'))".to_string())
    }

    /// Adds both `created_at` and `updated_at`, in that order.
    pub fn timestamps(self) -> Self {
        self.created_at().updated_at()
    }

    /// Adds a table-level `FOREIGN KEY (column) REFERENCES ref_table(ref_column)` constraint.
    ///
    /// Definitions are emitted in call order. SQLite's CREATE TABLE grammar
    /// expects table constraints after all column definitions, so call this
    /// after declaring the columns.
    pub fn foreign_key(self, column: &str, ref_table: &str, ref_column: &str) -> Self {
        self.push(format!(
            "FOREIGN KEY ({}) REFERENCES {}({})",
            column, ref_table, ref_column
        ))
    }

    /// Adds a table-level `UNIQUE (a, b, ...)` constraint over `columns`.
    pub fn unique(self, columns: &[&str]) -> Self {
        self.push(format!("UNIQUE ({})", columns.join(", ")))
    }

    /// Adds a raw, caller-written column or constraint definition verbatim.
    pub fn column(self, definition: &str) -> Self {
        self.push(definition.to_string())
    }

    /// Renders the `CREATE TABLE IF NOT EXISTS` statement, one definition per line.
    pub fn build(&self) -> String {
        format!(
            "CREATE TABLE IF NOT EXISTS {} (\n {}\n)",
            self.name,
            self.columns.join(",\n ")
        )
    }
}
/// Accumulates the DDL statements that make up one migration.
///
/// Every method appends exactly one statement and always returns `Ok(())`.
/// The `Result` return type matches the closure signature that
/// `MigrationBuilder::version` expects.
pub struct SchemaBuilder {
    statements: Vec<String>,
}
impl SchemaBuilder {
    /// Creates a builder with no statements.
    pub fn new() -> Self {
        SchemaBuilder { statements: vec![] }
    }

    /// Appends one finished statement; common tail of every public method.
    fn push(&mut self, statement: String) -> Result<(), SqliteError> {
        self.statements.push(statement);
        Ok(())
    }

    /// Appends `<create> IF NOT EXISTS name ON table (cols)`.
    fn push_index(
        &mut self,
        create: &str,
        name: &str,
        table: &str,
        columns: &[&str],
    ) -> Result<(), SqliteError> {
        let column_list = columns.join(", ");
        self.push(format!(
            "{} IF NOT EXISTS {} ON {} ({})",
            create, name, table, column_list
        ))
    }

    /// Adds a `CREATE TABLE IF NOT EXISTS` statement whose body is configured by `f`.
    pub fn create_table<F>(&mut self, name: &str, f: F) -> Result<(), SqliteError>
    where
        F: FnOnce(TableBuilder) -> TableBuilder,
    {
        let configured = f(TableBuilder::new(name));
        self.push(configured.build())
    }

    /// Adds a `CREATE INDEX IF NOT EXISTS` statement.
    pub fn create_index(&mut self, name: &str, table: &str, columns: &[&str]) -> Result<(), SqliteError> {
        self.push_index("CREATE INDEX", name, table, columns)
    }

    /// Adds a `CREATE UNIQUE INDEX IF NOT EXISTS` statement.
    pub fn create_unique_index(
        &mut self,
        name: &str,
        table: &str,
        columns: &[&str],
    ) -> Result<(), SqliteError> {
        self.push_index("CREATE UNIQUE INDEX", name, table, columns)
    }

    /// Adds a `DROP TABLE IF EXISTS` statement.
    pub fn drop_table(&mut self, name: &str) -> Result<(), SqliteError> {
        self.push(format!("DROP TABLE IF EXISTS {}", name))
    }

    /// Adds a `DROP INDEX IF EXISTS` statement.
    pub fn drop_index(&mut self, name: &str) -> Result<(), SqliteError> {
        self.push(format!("DROP INDEX IF EXISTS {}", name))
    }

    /// Adds caller-written SQL verbatim.
    pub fn raw(&mut self, sql: &str) -> Result<(), SqliteError> {
        self.push(sql.to_owned())
    }

    /// Joins all statements, in insertion order, with `";\n"`.
    pub fn build(&self) -> String {
        self.statements.join(";\n")
    }
}
impl Default for SchemaBuilder {
fn default() -> Self {
Self::new()
}
}
pub struct MigrationBuilder {
migrations: Vec<Migration>,
}
impl MigrationBuilder {
pub fn new() -> Self {
Self {
migrations: Vec::new(),
}
}
pub fn version<F>(&mut self, version: u32, description: &str, f: F)
where
F: FnOnce(&mut SchemaBuilder) -> Result<(), SqliteError>,
{
let mut schema = SchemaBuilder::new();
if f(&mut schema).is_ok() {
self.migrations.push(Migration::new(version, description, schema.build()));
}
}
pub fn build(self) -> Vec<Migration> {
self.migrations
}
}
impl Default for MigrationBuilder {
fn default() -> Self {
Self::new()
}
}
/// Applies [`Migration`]s to a [`Connection`], recording each applied version
/// in a `_migrations` bookkeeping table.
pub struct Migrator<'a> {
    conn: &'a Connection,
}
impl<'a> Migrator<'a> {
    /// Wraps a borrowed connection; no SQL is executed until a method is called.
    pub fn new(conn: &'a Connection) -> Self {
        Self { conn }
    }

    /// Creates the `_migrations` table if it does not exist yet.
    fn ensure_migration_table(&self) -> Result<(), SqliteError> {
        self.conn.execute_batch(
            "CREATE TABLE IF NOT EXISTS _migrations (
version INTEGER PRIMARY KEY,
description TEXT NOT NULL,
applied_at TEXT DEFAULT (datetime('now'))
)",
        )
    }

    /// Returns the highest recorded version, or 0 when nothing has been applied.
    pub fn current_version(&self) -> Result<u32, SqliteError> {
        self.ensure_migration_table()?;
        let row = self
            .conn
            .query_row("SELECT MAX(version) as v FROM _migrations", &[])?;
        // Rows are only inserted from `u32` versions (see `record_and_run`),
        // so this cast does not truncate values written by this type.
        Ok(row.and_then(|r| r.get_i64("v")).unwrap_or(0) as u32)
    }

    /// Returns whether `version` is recorded in `_migrations`.
    pub fn is_applied(&self, version: u32) -> Result<bool, SqliteError> {
        self.ensure_migration_table()?;
        self.is_recorded(version)
    }

    /// Looks `version` up in `_migrations`; the caller must have ensured the table exists.
    fn is_recorded(&self, version: u32) -> Result<bool, SqliteError> {
        let row = self.conn.query_row(
            "SELECT 1 FROM _migrations WHERE version = ?",
            &[version.into()],
        )?;
        Ok(row.is_some())
    }

    /// Executes `migration.sql` and records its version as one atomic unit.
    ///
    /// Both steps run inside a SAVEPOINT. If the SQL or the bookkeeping INSERT
    /// fails, everything the migration did is rolled back. This prevents a
    /// half-applied, unrecorded schema.
    ///
    /// NOTE: because of the enclosing savepoint, migration SQL must not issue
    /// its own `BEGIN`/`COMMIT`. Per the SQLite docs, `PRAGMA foreign_keys` is
    /// a no-op inside a transaction, so toggling it here has no effect.
    fn apply_migration(&self, migration: &Migration) -> Result<(), SqliteError> {
        self.conn.execute_batch("SAVEPOINT _migrator_apply")?;
        match self.run_and_record(migration) {
            Ok(()) => self.conn.execute_batch("RELEASE SAVEPOINT _migrator_apply"),
            Err(err) => {
                // Best-effort rollback: the migration's own error is the one worth
                // reporting, so a failure to roll back is deliberately ignored.
                let _ = self.conn.execute_batch(
                    "ROLLBACK TO SAVEPOINT _migrator_apply; RELEASE SAVEPOINT _migrator_apply",
                );
                Err(err)
            }
        }
    }

    /// Runs the migration SQL, then inserts its bookkeeping row.
    ///
    /// SQL failures are wrapped as `MigrationFailed("v<version>: ...")`.
    /// INSERT failures are returned unchanged.
    fn run_and_record(&self, migration: &Migration) -> Result<(), SqliteError> {
        self.conn
            .execute_batch(&migration.sql)
            .map_err(|e| SqliteError::MigrationFailed(format!("v{}: {}", migration.version, e)))?;
        self.conn.execute(
            "INSERT INTO _migrations (version, description) VALUES (?, ?)",
            &[migration.version.into(), migration.description.clone().into()],
        )?;
        Ok(())
    }

    /// Applies every not-yet-recorded migration in ascending version order.
    ///
    /// The sort is stable, so migrations sharing a version keep their input
    /// order. Once the first is recorded, the rest are counted as skipped.
    /// Stops at the first failing migration; earlier ones stay applied.
    pub fn migrate(&self, migrations: &[Migration]) -> Result<MigrationReport, SqliteError> {
        self.ensure_migration_table()?;
        let mut ordered: Vec<&Migration> = migrations.iter().collect();
        ordered.sort_by_key(|m| m.version);
        let mut report = MigrationReport::default();
        for migration in ordered {
            // The table was ensured above, so skip `is_applied`'s per-call CREATE.
            if self.is_recorded(migration.version)? {
                report.skipped += 1;
            } else {
                self.apply_migration(migration)?;
                report.applied += 1;
            }
        }
        report.current_version = self.current_version()?;
        Ok(report)
    }

    /// Builds migrations with a [`MigrationBuilder`] and applies them via [`Migrator::migrate`].
    ///
    /// Versions whose builder closure returned an error are omitted by
    /// `MigrationBuilder::build` and therefore not applied.
    pub fn migrate_with<F>(&self, f: F) -> Result<MigrationReport, SqliteError>
    where
        F: FnOnce(&mut MigrationBuilder),
    {
        let mut builder = MigrationBuilder::new();
        f(&mut builder);
        self.migrate(&builder.build())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that each fragment occurs somewhere in `sql`.
    fn assert_contains_all(sql: &str, fragments: &[&str]) {
        for &fragment in fragments {
            assert!(sql.contains(fragment));
        }
    }

    /// Version-1 schema shared by both migration runs in `test_migration`.
    fn create_users(s: &mut SchemaBuilder) -> Result<(), SqliteError> {
        s.create_table("users", |t| t.id().text_not_null("name").timestamps())
    }

    #[test]
    fn test_table_builder() {
        let table = TableBuilder::new("users")
            .id()
            .text_not_null("name")
            .integer("age")
            .timestamps();
        let sql = table.build();
        assert_contains_all(
            &sql,
            &[
                "CREATE TABLE",
                "id INTEGER PRIMARY KEY",
                "name TEXT NOT NULL",
                "created_at",
            ],
        );
    }

    #[test]
    fn test_migration() {
        let conn = Connection::open_in_memory().unwrap();
        let migrator = Migrator::new(&conn);

        let first = migrator
            .migrate_with(|m| {
                m.version(1, "创建用户表", create_users);
                m.version(2, "添加索引", |s| {
                    s.create_index("idx_users_name", "users", &["name"])
                });
            })
            .unwrap();
        assert_eq!(first.applied, 2);
        assert_eq!(first.current_version, 2);

        // Re-running an already-applied version must be a no-op.
        let second = migrator
            .migrate_with(|m| {
                m.version(1, "创建用户表", create_users);
            })
            .unwrap();
        assert_eq!(second.applied, 0);
        assert_eq!(second.skipped, 1);
    }

    #[test]
    fn test_schema_builder() {
        let mut schema = SchemaBuilder::new();
        schema
            .create_table("posts", |t| {
                t.id()
                    .text_not_null("title")
                    .text("content")
                    .integer_not_null("user_id")
                    .foreign_key("user_id", "users", "id")
            })
            .unwrap();
        schema
            .create_index("idx_posts_user", "posts", &["user_id"])
            .unwrap();
        assert_contains_all(
            &schema.build(),
            &["CREATE TABLE", "FOREIGN KEY", "CREATE INDEX"],
        );
    }
}