1use std::path::{Path, PathBuf};
4use std::fs;
5use chrono::Utc;
6use sqlx::Row;
7
/// A single schema migration: a versioned pair of SQL scripts stored in one
/// `.sql` file under `-- migrate:up` / `-- migrate:down` section markers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Migration {
    /// Identifier of the migration and the stem of its `.sql` file.
    /// `Migration::new` builds it as `<YYYYMMDDHHMMSS>_<name>`.
    pub version: String,
    /// Human-readable name: the part of `version` after the first `_`.
    pub name: String,
    /// Contents of the `-- migrate:up` section (presumably run to apply the
    /// migration — the executor is not shown here).
    pub up_sql: String,
    /// Contents of the `-- migrate:down` section (presumably run to roll the
    /// migration back).
    pub down_sql: String,
}
16
17impl Migration {
18 pub fn new(name: &str) -> Self {
19 let timestamp = Utc::now().format("%Y%m%d%H%M%S").to_string();
20 let version = format!("{}_{}", timestamp, name);
21
22 Self {
23 version,
24 name: name.to_string(),
25 up_sql: String::new(),
26 down_sql: String::new(),
27 }
28 }
29
30 pub fn from_file(path: impl AsRef<Path>) -> Result<Self, std::io::Error> {
31 let path = path.as_ref();
32 let content = fs::read_to_string(path)?;
33
34 let filename = path.file_stem()
35 .and_then(|s| s.to_str())
36 .ok_or_else(|| std::io::Error::new(std::io::ErrorKind::InvalidInput, "Invalid filename"))?;
37
38 let parts: Vec<&str> = filename.splitn(2, '_').collect();
40 if parts.len() != 2 {
41 return Err(std::io::Error::new(std::io::ErrorKind::InvalidInput, "Invalid migration filename"));
42 }
43
44 let version = filename.to_string();
45 let name = parts[1].to_string();
46
47 let sections: Vec<&str> = content.split("-- migrate:down").collect();
49 let up_sql = sections.get(0)
50 .unwrap_or(&"")
51 .replace("-- migrate:up", "")
52 .trim()
53 .to_string();
54 let down_sql = sections.get(1)
55 .unwrap_or(&"")
56 .trim()
57 .to_string();
58
59 Ok(Self {
60 version,
61 name,
62 up_sql,
63 down_sql,
64 })
65 }
66
67 pub fn save(&self, migrations_dir: impl AsRef<Path>) -> Result<PathBuf, std::io::Error> {
68 let migrations_dir = migrations_dir.as_ref();
69 fs::create_dir_all(migrations_dir)?;
70
71 let filename = format!("{}.sql", self.version);
72 let path = migrations_dir.join(filename);
73
74 let content = format!(
75 "-- migrate:up\n{}\n\n-- migrate:down\n{}\n",
76 self.up_sql,
77 self.down_sql
78 );
79
80 fs::write(&path, content)?;
81 Ok(path)
82 }
83}
84
/// Discovers, creates and tracks migrations stored as `.sql` files in one directory.
#[derive(Debug, Clone)]
pub struct MigrationManager {
    /// Directory holding the `<version>.sql` migration files.
    migrations_dir: PathBuf,
}
89
90impl MigrationManager {
91 pub fn new(migrations_dir: impl AsRef<Path>) -> Self {
92 Self {
93 migrations_dir: migrations_dir.as_ref().to_path_buf(),
94 }
95 }
96
97 pub fn list_migrations(&self) -> Result<Vec<Migration>, std::io::Error> {
98 let mut migrations = Vec::new();
99
100 if !self.migrations_dir.exists() {
101 return Ok(migrations);
102 }
103
104 for entry in fs::read_dir(&self.migrations_dir)? {
105 let entry = entry?;
106 let path = entry.path();
107
108 if path.is_file() && path.extension().and_then(|s| s.to_str()) == Some("sql") {
109 if let Ok(migration) = Migration::from_file(&path) {
110 migrations.push(migration);
111 }
112 }
113 }
114
115 migrations.sort_by(|a, b| a.version.cmp(&b.version));
117
118 Ok(migrations)
119 }
120
121 pub fn create_migration(&self, name: &str) -> Result<PathBuf, std::io::Error> {
122 let migration = Migration::new(name);
123 migration.save(&self.migrations_dir)
124 }
125
126 pub async fn ensure_migrations_table(&self, db: &impl crate::Database) -> crate::Result<()> {
128 let sql = r#"
129 CREATE TABLE IF NOT EXISTS _migrations (
130 id INTEGER PRIMARY KEY AUTOINCREMENT,
131 version TEXT NOT NULL UNIQUE,
132 applied_at INTEGER NOT NULL
133 )
134 "#;
135 db.execute(sql).await?;
136 Ok(())
137 }
138
139 pub async fn get_applied_migrations(&self, db: &impl crate::Database) -> crate::Result<Vec<String>> {
141 self.ensure_migrations_table(db).await?;
142
143 let rows = db.query("SELECT version FROM _migrations ORDER BY version").await?;
144 let mut versions = Vec::new();
145
146 for row in rows {
147 if let Ok(version) = row.try_get::<String, _>("version") {
148 versions.push(version);
149 }
150 }
151
152 Ok(versions)
153 }
154
155 pub async fn mark_migration_applied(&self, db: &impl crate::Database, version: &str) -> crate::Result<()> {
157 self.ensure_migrations_table(db).await?;
158
159 let timestamp = chrono::Utc::now().timestamp();
160 let sql = format!(
161 "INSERT INTO _migrations (version, applied_at) VALUES ('{}', {})",
162 version, timestamp
163 );
164 db.execute(&sql).await?;
165 Ok(())
166 }
167
168 pub async fn mark_migration_reverted(&self, db: &impl crate::Database, version: &str) -> crate::Result<()> {
170 let sql = format!("DELETE FROM _migrations WHERE version = '{}'", version);
171 db.execute(&sql).await?;
172 Ok(())
173 }
174
175 pub async fn get_pending_migrations(&self, db: &impl crate::Database) -> Result<Vec<Migration>, Box<dyn std::error::Error>> {
177 let all_migrations = self.list_migrations()?;
178 let applied = self.get_applied_migrations(db).await?;
179
180 let pending: Vec<Migration> = all_migrations
181 .into_iter()
182 .filter(|m| !applied.contains(&m.version))
183 .collect();
184
185 Ok(pending)
186 }
187}