oxify_storage/
migration_runner.rs1use crate::{DatabasePool, Result, StorageError};
35use chrono::{DateTime, Utc};
36use sqlx::Row;
37use std::collections::HashMap;
38
/// A single schema migration: the SQL to apply it and, optionally, to undo it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Migration {
    /// Unique version identifier; used as the primary key of the
    /// `schema_migrations` tracking table.
    pub version: String,
    /// Human-readable name, stored alongside the version when applied.
    pub name: String,
    /// SQL executed to apply the migration. Its SHA-256 checksum is recorded
    /// so later edits to the SQL can be detected.
    pub up_sql: String,
    /// SQL executed to roll the migration back; `None` means the migration
    /// cannot be rolled back.
    pub down_sql: Option<String>,
    /// Optional free-form description of what the migration does.
    pub description: Option<String>,
}
53
54#[derive(Debug, Clone)]
56pub struct MigrationStatus {
57 pub version: String,
59 pub name: String,
61 pub applied: bool,
63 pub applied_at: Option<DateTime<Utc>>,
65 pub checksum: Option<String>,
67}
68
/// One row of the `schema_migrations` tracking table, as read back from the
/// database.
#[derive(Debug, Clone)]
struct MigrationRecord {
    /// Version identifier (primary key of the tracking table).
    version: String,
    /// Name stored when the migration was applied. It is loaded from the row
    /// but not currently read anywhere in this module, hence the allow.
    #[allow(dead_code)]
    name: String,
    /// Raw `applied_at` column text, as stored by the database.
    applied_at: String,
    /// Hex SHA-256 checksum of the `up_sql` that was applied.
    checksum: String,
}
78
79pub struct MigrationRunner {
81 pool: DatabasePool,
82}
83
84impl MigrationRunner {
85 pub fn new(pool: DatabasePool) -> Self {
87 Self { pool }
88 }
89
90 pub async fn initialize(&self) -> Result<()> {
94 sqlx::query(
95 r"
96 CREATE TABLE IF NOT EXISTS schema_migrations (
97 version TEXT PRIMARY KEY,
98 name TEXT NOT NULL,
99 applied_at TEXT NOT NULL DEFAULT (datetime('now')),
100 checksum TEXT NOT NULL,
101 execution_time_ms INTEGER
102 )
103 ",
104 )
105 .execute(self.pool.pool())
106 .await?;
107
108 sqlx::query(
110 r"
111 CREATE INDEX IF NOT EXISTS idx_schema_migrations_applied_at
112 ON schema_migrations(applied_at DESC)
113 ",
114 )
115 .execute(self.pool.pool())
116 .await?;
117
118 Ok(())
119 }
120
121 pub fn get_builtin_migrations() -> Vec<Migration> {
125 vec![
126 Migration {
127 version: "001_performance_indexes".to_string(),
128 name: "Add Performance Indexes".to_string(),
129 description: Some("Creates indexes for improved query performance".to_string()),
130 up_sql: include_str!("../sql/001_performance_indexes.sql").to_string(),
131 down_sql: Some(include_str!("../sql/001_performance_indexes_down.sql").to_string()),
132 },
133 Migration {
134 version: "002_schema_constraints".to_string(),
135 name: "Add Schema Constraints".to_string(),
136 description: Some("Adds foreign keys and CHECK constraints".to_string()),
137 up_sql: include_str!("../sql/002_schema_constraints.sql").to_string(),
138 down_sql: Some(include_str!("../sql/002_schema_constraints_down.sql").to_string()),
139 },
140 ]
141 }
142
143 pub async fn is_applied(&self, version: &str) -> Result<bool> {
145 let row = sqlx::query(
146 r"
147 SELECT COUNT(*) as count FROM schema_migrations WHERE version = ?
148 ",
149 )
150 .bind(version)
151 .fetch_one(self.pool.pool())
152 .await?;
153
154 let count: i64 = row.get("count");
155 Ok(count > 0)
156 }
157
158 pub async fn apply_migration(&self, migration: &Migration) -> Result<i64> {
162 if self.is_applied(&migration.version).await? {
164 return Ok(0);
165 }
166
167 let start_time = std::time::Instant::now();
168 let mut tx = self.pool.pool().begin().await?;
169
170 let up_sql: &'static str = Box::leak(migration.up_sql.clone().into_boxed_str());
173 sqlx::query(up_sql).execute(&mut *tx).await?;
174
175 let checksum = Self::calculate_checksum(&migration.up_sql);
177
178 let execution_time_ms = start_time.elapsed().as_millis() as i64;
180 sqlx::query(
181 r"
182 INSERT INTO schema_migrations (version, name, checksum, execution_time_ms)
183 VALUES (?, ?, ?, ?)
184 ",
185 )
186 .bind(&migration.version)
187 .bind(&migration.name)
188 .bind(&checksum)
189 .bind(execution_time_ms)
190 .execute(&mut *tx)
191 .await?;
192
193 tx.commit().await?;
194
195 Ok(execution_time_ms)
196 }
197
198 pub async fn run_pending_migrations(&self) -> Result<Vec<String>> {
202 self.initialize().await?;
203
204 let migrations = Self::get_builtin_migrations();
205 let mut applied = Vec::new();
206
207 for migration in migrations {
208 if !self.is_applied(&migration.version).await? {
209 self.apply_migration(&migration).await?;
210 applied.push(migration.version.clone());
211 }
212 }
213
214 Ok(applied)
215 }
216
217 pub async fn get_migration_status(&self) -> Result<Vec<MigrationStatus>> {
219 self.initialize().await?;
220
221 let migrations = Self::get_builtin_migrations();
222 let applied_records = self.get_applied_migrations().await?;
223
224 let applied_map: HashMap<String, MigrationRecord> = applied_records
225 .into_iter()
226 .map(|r| (r.version.clone(), r))
227 .collect();
228
229 let mut status = Vec::new();
230 for migration in migrations {
231 let record = applied_map.get(&migration.version);
232 status.push(MigrationStatus {
233 version: migration.version,
234 name: migration.name,
235 applied: record.is_some(),
236 applied_at: record.and_then(|r| {
237 DateTime::parse_from_rfc3339(&r.applied_at)
238 .ok()
239 .map(|dt| dt.with_timezone(&Utc))
240 }),
241 checksum: record.map(|r| r.checksum.clone()),
242 });
243 }
244
245 Ok(status)
246 }
247
248 async fn get_applied_migrations(&self) -> Result<Vec<MigrationRecord>> {
250 let rows = sqlx::query(
251 r"
252 SELECT version, name, applied_at, checksum
253 FROM schema_migrations
254 ORDER BY applied_at ASC
255 ",
256 )
257 .fetch_all(self.pool.pool())
258 .await?;
259
260 let records: Vec<MigrationRecord> = rows
261 .into_iter()
262 .map(|row| MigrationRecord {
263 version: row.get("version"),
264 name: row.get("name"),
265 applied_at: row.get("applied_at"),
266 checksum: row.get("checksum"),
267 })
268 .collect();
269
270 Ok(records)
271 }
272
273 pub async fn rollback_migration(&self, version: &str) -> Result<()> {
277 let migrations = Self::get_builtin_migrations();
278 let migration = migrations
279 .iter()
280 .find(|m| m.version == version)
281 .ok_or_else(|| {
282 StorageError::NotFoundLegacy(format!("Migration {version} not found"))
283 })?;
284
285 let down_sql = migration.down_sql.as_ref().ok_or_else(|| {
286 StorageError::ValidationError(format!("Migration {version} has no rollback defined"))
287 })?;
288
289 let mut tx = self.pool.pool().begin().await?;
290
291 let down_sql_static: &'static str = Box::leak(down_sql.clone().into_boxed_str());
294 sqlx::query(down_sql_static).execute(&mut *tx).await?;
295
296 sqlx::query("DELETE FROM schema_migrations WHERE version = ?")
298 .bind(version)
299 .execute(&mut *tx)
300 .await?;
301
302 tx.commit().await?;
303
304 Ok(())
305 }
306
307 fn calculate_checksum(sql: &str) -> String {
309 use sha2::{Digest, Sha256};
310 let mut hasher = Sha256::new();
311 hasher.update(sql.as_bytes());
312 format!("{:x}", hasher.finalize())
313 }
314
315 pub async fn verify_checksums(&self) -> Result<Vec<String>> {
319 let migrations = Self::get_builtin_migrations();
320 let applied = self.get_applied_migrations().await?;
321 let mut mismatches = Vec::new();
322
323 let applied_map: HashMap<String, MigrationRecord> = applied
324 .into_iter()
325 .map(|r| (r.version.clone(), r))
326 .collect();
327
328 for migration in migrations {
329 if let Some(record) = applied_map.get(&migration.version) {
330 let current_checksum = Self::calculate_checksum(&migration.up_sql);
331 if current_checksum != record.checksum {
332 mismatches.push(migration.version.clone());
333 }
334 }
335 }
336
337 Ok(mismatches)
338 }
339}
340
#[cfg(test)]
mod tests {
    use super::*;

    /// A SHA-256 digest rendered as hex is 64 characters long.
    #[test]
    fn test_checksum_calculation() {
        let digest = MigrationRunner::calculate_checksum("CREATE TABLE test (id INT)");
        assert_eq!(digest.len(), 64);
    }

    /// Hashing the same SQL twice yields the same digest.
    #[test]
    fn test_checksum_consistency() {
        let query = "SELECT * FROM users";
        let (first, second) = (
            MigrationRunner::calculate_checksum(query),
            MigrationRunner::calculate_checksum(query),
        );
        assert_eq!(first, second);
    }

    /// Both shipped migrations appear in the built-in list.
    #[test]
    fn test_builtin_migrations_exist() {
        let builtin = MigrationRunner::get_builtin_migrations();
        assert!(!builtin.is_empty());
        for fragment in ["performance_indexes", "schema_constraints"] {
            assert!(builtin.iter().any(|m| m.version.contains(fragment)));
        }
    }
}
371}