use crate::{
adapter::AdapterKind,
error::{DataError, DataResult},
};
use serde::{Deserialize, Serialize};
use std::{
fs,
path::{Path, PathBuf},
time::SystemTime,
};
/// One migration definition loaded from a `<id>_<name>.sql` file by [`load_migrations`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Migration {
/// The file-name text before the first `_`; used to match against applied state.
pub id: String,
/// The file-name text after the first `_`, without the `.sql` suffix.
pub name: String,
/// Trimmed SQL between the `-- +up` and `-- +down` markers.
pub up_sql: String,
/// Trimmed SQL after the `-- +down` marker; an empty value makes rollback of this
/// migration fail.
pub down_sql: String,
/// Path of the file this migration was read from; shown in error messages.
pub path: PathBuf,
}
/// A migration recorded as applied in the engine's JSON state file.
///
/// Field names double as JSON keys in the persisted state, so renaming them changes the
/// on-disk format.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct AppliedMigration {
/// Id of the [`Migration`] that was recorded.
pub id: String,
/// Name of the [`Migration`] at the time it was recorded.
pub name: String,
/// System-clock time the migration was recorded, in milliseconds since the Unix epoch.
pub applied_at_unix_ms: u64,
}
/// Snapshot returned by [`MigrationEngine::status`].
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct MigrationStatus {
/// Migrations recorded in the state file, in the order they were recorded.
pub applied: Vec<AppliedMigration>,
/// Supplied migrations whose id is not recorded as applied, in the order supplied.
pub pending: Vec<Migration>,
}
/// On-disk migration record, stored as pretty-printed JSON at
/// `<project_root>/.shelly/migrations/<adapter>.json`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct MigrationState {
// Persisted for readability only: `read_state` overwrites it with the engine's adapter.
adapter: AdapterKind,
// Applied migrations in application order; rollback pops from the end.
applied: Vec<AppliedMigration>,
}
/// Tracks which migrations are applied for one adapter by reading and writing a JSON state
/// file under the project root.
///
/// NOTE(review): in this file `migrate`/`rollback` only validate that SQL is present and
/// update the state file; no SQL is executed here. Confirm whether execution happens elsewhere.
pub struct MigrationEngine {
// Directory under which `.shelly/migrations/` is created.
project_root: PathBuf,
// Adapter whose state file this engine owns; also names that file.
adapter: AdapterKind,
}
impl MigrationEngine {
    /// Creates an engine that records migration state for `adapter` under `project_root`.
    ///
    /// # Errors
    ///
    /// Returns `DataError::Migration` when `adapter` is `AdapterKind::None` or when
    /// `adapter.supports_migrations()` is false.
    pub fn new(project_root: impl Into<PathBuf>, adapter: AdapterKind) -> DataResult<Self> {
        if adapter == AdapterKind::None {
            return Err(DataError::Migration(
                "adapter is `none`; choose a SQL adapter before running migrations".to_string(),
            ));
        }
        if !adapter.supports_migrations() {
            return Err(DataError::Migration(format!(
                "adapter `{}` does not support SQL migrations",
                adapter.as_str()
            )));
        }
        Ok(Self {
            project_root: project_root.into(),
            adapter,
        })
    }

    /// Reports which migrations are recorded as applied and which of `all_migrations` are
    /// still pending.
    ///
    /// `applied` is returned in recorded order; `pending` keeps the order of `all_migrations`.
    ///
    /// # Errors
    ///
    /// Propagates I/O and JSON errors from reading the state file.
    pub fn status(&self, all_migrations: &[Migration]) -> DataResult<MigrationStatus> {
        let state = self.read_state()?;
        let pending = Self::pending_migrations(&state.applied, all_migrations);
        Ok(MigrationStatus {
            applied: state.applied,
            pending,
        })
    }

    /// Records pending migrations as applied and returns the newly recorded entries.
    ///
    /// Pending migrations are taken in `all_migrations` order. At most `steps` are recorded,
    /// or all of them when `steps` is `None`. Each entry is stamped with the current time.
    /// The state file is rewritten even when nothing was pending.
    ///
    /// # Errors
    ///
    /// Returns `DataError::Migration` if a selected migration has empty up SQL. In that case
    /// the state file is left untouched, so no migration from this call is recorded.
    /// Propagates I/O and JSON errors from reading or writing the state file.
    pub fn migrate(
        &self,
        all_migrations: &[Migration],
        steps: Option<usize>,
    ) -> DataResult<Vec<AppliedMigration>> {
        let mut state = self.read_state()?;
        let pending = Self::pending_migrations(&state.applied, all_migrations);
        let limit = steps.unwrap_or(pending.len());
        let mut applied_now = Vec::with_capacity(limit.min(pending.len()));
        for migration in pending.into_iter().take(limit) {
            if migration.up_sql.trim().is_empty() {
                return Err(DataError::Migration(format!(
                    "migration {} has empty up SQL",
                    migration.path.display()
                )));
            }
            let applied = AppliedMigration {
                id: migration.id,
                name: migration.name,
                applied_at_unix_ms: now_unix_ms(),
            };
            state.applied.push(applied.clone());
            applied_now.push(applied);
        }
        // Persist only after every selected migration passed validation.
        self.write_state(&state)?;
        Ok(applied_now)
    }

    /// Removes the most recently applied migrations from the state file and returns them,
    /// most recent first.
    ///
    /// A `steps` of 0 is treated as 1. Rollback stops early when fewer applied migrations
    /// remain than requested.
    ///
    /// # Errors
    ///
    /// Returns `DataError::Migration` when a migration to roll back has no matching entry
    /// in `all_migrations`, or when its down SQL is empty. In either case the state file is
    /// left untouched. Propagates I/O and JSON errors from the state file.
    pub fn rollback(
        &self,
        all_migrations: &[Migration],
        steps: usize,
    ) -> DataResult<Vec<AppliedMigration>> {
        let mut state = self.read_state()?;
        let steps = steps.max(1);
        let mut rolled_back = Vec::<AppliedMigration>::new();
        for _ in 0..steps {
            let Some(last) = state.applied.pop() else {
                break;
            };
            let Some(definition) = all_migrations
                .iter()
                .find(|migration| migration.id == last.id)
            else {
                return Err(DataError::Migration(format!(
                    "cannot rollback migration `{}` because file is missing",
                    last.id
                )));
            };
            if definition.down_sql.trim().is_empty() {
                return Err(DataError::Migration(format!(
                    "migration {} has empty down SQL",
                    definition.path.display()
                )));
            }
            rolled_back.push(last);
        }
        self.write_state(&state)?;
        Ok(rolled_back)
    }

    /// Returns clones of the entries in `all_migrations` whose id is not listed in `applied`,
    /// keeping the order of `all_migrations`.
    fn pending_migrations(
        applied: &[AppliedMigration],
        all_migrations: &[Migration],
    ) -> Vec<Migration> {
        use std::collections::HashSet;

        // One hash lookup per migration instead of rescanning `applied` for each one.
        let applied_ids: HashSet<&str> = applied.iter().map(|item| item.id.as_str()).collect();
        all_migrations
            .iter()
            .filter(|migration| !applied_ids.contains(migration.id.as_str()))
            .cloned()
            .collect()
    }

    /// Location of this adapter's state file:
    /// `<project_root>/.shelly/migrations/<adapter>.json`.
    fn state_path(&self) -> PathBuf {
        self.project_root
            .join(".shelly")
            .join("migrations")
            .join(format!("{}.json", self.adapter.as_str()))
    }

    /// Loads the state file, or returns an empty state when the file does not exist yet.
    ///
    /// The stored `adapter` field is always replaced by the engine's own adapter.
    fn read_state(&self) -> DataResult<MigrationState> {
        let state_path = self.state_path();
        if !state_path.exists() {
            return Ok(MigrationState {
                adapter: self.adapter,
                applied: Vec::new(),
            });
        }
        let raw = fs::read_to_string(state_path)?;
        let mut state: MigrationState = serde_json::from_str(&raw)?;
        state.adapter = self.adapter;
        Ok(state)
    }

    /// Writes `state` as pretty JSON (plus a trailing newline), creating parent directories.
    fn write_state(&self, state: &MigrationState) -> DataResult<()> {
        let state_path = self.state_path();
        if let Some(parent) = state_path.parent() {
            fs::create_dir_all(parent)?;
        }
        let body = serde_json::to_string_pretty(state)?;
        // Write a sibling temp file and rename it over the real one. An interrupted write
        // then cannot leave a truncated state file that `read_state` would fail to parse.
        let tmp_path = state_path.with_extension("json.tmp");
        fs::write(&tmp_path, format!("{body}\n"))?;
        fs::rename(&tmp_path, &state_path)?;
        Ok(())
    }
}
pub fn load_migrations(dir: &Path) -> DataResult<Vec<Migration>> {
if !dir.exists() {
return Ok(Vec::new());
}
let mut entries = fs::read_dir(dir)?
.filter_map(|entry| entry.ok())
.map(|entry| entry.path())
.filter(|path| path.extension().is_some_and(|extension| extension == "sql"))
.collect::<Vec<_>>();
entries.sort();
let mut migrations = Vec::with_capacity(entries.len());
for path in entries {
let Some(file_name) = path.file_name().and_then(|name| name.to_str()) else {
continue;
};
let Some((id, name)) = parse_file_id_name(file_name) else {
continue;
};
let source = fs::read_to_string(&path)?;
let (up_sql, down_sql) = parse_up_down(&source, &path)?;
migrations.push(Migration {
id,
name,
up_sql,
down_sql,
path,
});
}
Ok(migrations)
}
/// Splits a migration file name of the form `<id>_<name>.sql` at the first underscore.
///
/// Returns `None` in any of these cases:
/// - the name lacks the `.sql` suffix;
/// - the name contains no underscore;
/// - the id or name part would be empty (e.g. `_create.sql`, `001_.sql`).
fn parse_file_id_name(file_name: &str) -> Option<(String, String)> {
    let stem = file_name.strip_suffix(".sql")?;
    let (id, name) = stem.split_once('_')?;
    // A blank id cannot meaningfully identify a migration in the state file.
    if id.is_empty() || name.is_empty() {
        return None;
    }
    Some((id.to_owned(), name.to_owned()))
}
/// Splits a migration file's contents into trimmed `(up_sql, down_sql)`.
///
/// The up SQL is the text between the `-- +up` and `-- +down` markers. The down SQL is
/// everything after `-- +down`. Markers are matched as plain substrings.
///
/// # Errors
///
/// Returns `DataError::Migration` when either marker is missing, appears more than once,
/// or when `-- +down` comes before `-- +up`. `path` is used only in error messages.
fn parse_up_down(source: &str, path: &Path) -> DataResult<(String, String)> {
    let up_marker = "-- +up";
    let down_marker = "-- +down";
    // Find the single occurrence of `marker`. A repeated marker is rejected because
    // otherwise the extra section lands silently in the wrong half. For example, a second
    // `-- +up` section would become part of the down SQL that runs on rollback.
    let locate = |marker: &str| -> DataResult<usize> {
        let mut hits = source.match_indices(marker).map(|(index, _)| index);
        let Some(first) = hits.next() else {
            return Err(DataError::Migration(format!(
                "migration {} missing `{marker}` marker",
                path.display()
            )));
        };
        if hits.next().is_some() {
            return Err(DataError::Migration(format!(
                "migration {} has more than one `{marker}` marker",
                path.display()
            )));
        }
        Ok(first)
    };
    let up_start = locate(up_marker)?;
    let down_start = locate(down_marker)?;
    if down_start <= up_start {
        return Err(DataError::Migration(format!(
            "migration {} has invalid marker order",
            path.display()
        )));
    }
    // Markers are ASCII, so these byte offsets are valid char boundaries.
    let up_sql = source[up_start + up_marker.len()..down_start]
        .trim()
        .to_string();
    let down_sql = source[down_start + down_marker.len()..].trim().to_string();
    Ok((up_sql, down_sql))
}
/// Current system-clock time in milliseconds since the Unix epoch.
///
/// Returns 0 if the clock reads earlier than the epoch. The `u128` millisecond count
/// saturates at `u64::MAX` instead of being silently truncated.
fn now_unix_ms() -> u64 {
    let millis = SystemTime::now()
        .duration_since(SystemTime::UNIX_EPOCH)
        .unwrap_or_default()
        .as_millis();
    u64::try_from(millis).unwrap_or(u64::MAX)
}
#[cfg(test)]
mod tests {
    // Filesystem-backed tests for loading, applying and rolling back migrations.
    use super::{load_migrations, MigrationEngine};
    use crate::AdapterKind;
    use std::{
        fs,
        path::{Path, PathBuf},
        time::SystemTime,
    };

    // A well-formed migration shared by tests that only need a valid up/down pair.
    const CREATE_POSTS: &str = r#"
-- +up
CREATE TABLE posts(id BIGINT PRIMARY KEY);
-- +down
DROP TABLE posts;
"#;

    #[test]
    fn migration_lifecycle_applies_and_rolls_back() {
        let root = temp_path("shelly_data_migration");
        let dir = fresh_migrations_dir(&root);
        write_migration(&dir, "20260505120000_create_posts.sql", CREATE_POSTS);
        let migrations = load_migrations(&dir).unwrap();
        let engine = MigrationEngine::new(&root, AdapterKind::Sqlite).unwrap();

        assert_eq!(engine.migrate(&migrations, None).unwrap().len(), 1);
        let after_up = engine.status(&migrations).unwrap();
        assert_eq!(after_up.applied.len(), 1);
        assert_eq!(after_up.pending.len(), 0);

        assert_eq!(engine.rollback(&migrations, 1).unwrap().len(), 1);
        let after_down = engine.status(&migrations).unwrap();
        assert_eq!(after_down.applied.len(), 0);

        fs::remove_dir_all(root).unwrap();
    }

    #[test]
    fn migration_loader_rejects_invalid_marker_order() {
        let root = temp_path("shelly_data_invalid_marker_order");
        let dir = fresh_migrations_dir(&root);
        write_migration(
            &dir,
            "20260505130000_invalid.sql",
            r#"
-- +down
DROP TABLE posts;
-- +up
CREATE TABLE posts(id BIGINT PRIMARY KEY);
"#,
        );

        let message = load_migrations(&dir).unwrap_err().to_string();
        assert!(message.contains("invalid marker order"));

        fs::remove_dir_all(root).unwrap();
    }

    #[test]
    fn rollback_fails_when_applied_migration_file_is_missing() {
        let root = temp_path("shelly_data_missing_migration_file");
        let dir = fresh_migrations_dir(&root);
        let file = write_migration(&dir, "20260505140000_create_posts.sql", CREATE_POSTS);
        let loaded = load_migrations(&dir).unwrap();
        let engine = MigrationEngine::new(&root, AdapterKind::Sqlite).unwrap();
        engine.migrate(&loaded, None).unwrap();

        // Delete the file after it was recorded as applied, then reload without it.
        fs::remove_file(&file).unwrap();
        let reloaded = load_migrations(&dir).unwrap();
        let message = engine.rollback(&reloaded, 1).unwrap_err().to_string();
        assert!(message.contains("cannot rollback migration"));
        assert!(message.contains("file is missing"));

        fs::remove_dir_all(root).unwrap();
    }

    #[test]
    fn rollback_fails_when_down_sql_is_empty() {
        let root = temp_path("shelly_data_empty_down_sql");
        let dir = fresh_migrations_dir(&root);
        write_migration(
            &dir,
            "20260505150000_create_posts.sql",
            r#"
-- +up
CREATE TABLE posts(id BIGINT PRIMARY KEY);
-- +down
"#,
        );
        let loaded = load_migrations(&dir).unwrap();
        let engine = MigrationEngine::new(&root, AdapterKind::Sqlite).unwrap();
        engine.migrate(&loaded, None).unwrap();

        let message = engine.rollback(&loaded, 1).unwrap_err().to_string();
        assert!(message.contains("empty down SQL"));

        fs::remove_dir_all(root).unwrap();
    }

    /// Creates `<root>/migrations` (and `root` itself) and returns its path.
    fn fresh_migrations_dir(root: &Path) -> PathBuf {
        let dir = root.join("migrations");
        fs::create_dir_all(&dir).unwrap();
        dir
    }

    /// Writes `body` to `dir/file_name` and returns the written file's path.
    fn write_migration(dir: &Path, file_name: &str, body: &str) -> PathBuf {
        let file = dir.join(file_name);
        fs::write(&file, body).unwrap();
        file
    }

    /// Builds a path in the system temp dir made unique by the current time in nanoseconds.
    fn temp_path(prefix: &str) -> PathBuf {
        let stamp = SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        std::env::temp_dir().join(format!("{prefix}_{stamp}"))
    }
}