1use rusqlite::Connection;
7use std::collections::HashSet;
8use thiserror::Error;
9
10#[derive(Debug, Error)]
12pub enum Error {
13 #[error("Database error: {0}")]
15 Database(#[from] rusqlite::Error),
16
17 #[error("IO error: {0}")]
19 Io(#[from] std::io::Error),
20
21 #[error("Migration '{id}' failed: {message}")]
23 MigrationFailed { id: String, message: String },
24
25 #[error("Environment variable '{0}' not set")]
27 EnvVarNotFound(String),
28}
29
30pub type MigrateResult<T> = std::result::Result<T, Error>;
31
/// A single schema migration: a unique identifier plus the SQL that implements it.
///
/// Both fields are `&'static str`, so migrations are usually embedded at compile time
/// (for example with `include_str!`). This also makes the type trivially `Copy`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Migration {
    /// Identifier recorded in the `_migrations` table once the migration has been applied.
    pub id: &'static str,
    /// SQL text executed as one batch when the migration is applied.
    pub sql: &'static str,
}

impl Migration {
    /// Creates a migration from its identifier and SQL.
    ///
    /// This is a `const fn`, so migrations can be declared in `const` or `static` items.
    #[must_use]
    pub const fn new(id: &'static str, sql: &'static str) -> Self {
        Self { id, sql }
    }
}
51
52fn ensure_migrations_table(conn: &mut Connection) -> MigrateResult<()> {
58 conn.execute(
59 "CREATE TABLE IF NOT EXISTS _migrations (
60 id TEXT PRIMARY KEY,
61 applied_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP
62 )",
63 [],
64 )?;
65 Ok(())
66}
67
68fn get_applied_migrations(conn: &Connection) -> MigrateResult<HashSet<String>> {
70 let mut statement = conn.prepare("SELECT id FROM _migrations")?;
71
72 let migration_ids = statement.query_map([], |row| row.get::<_, String>(0))?;
73
74 let mut applied_set = HashSet::new();
75 for id in migration_ids.into_iter().flatten() {
76 applied_set.insert(id);
77 }
78
79 Ok(applied_set)
80}
81
82pub fn up(conn: &mut Connection, migrations: &[Migration]) -> MigrateResult<()> {
102 ensure_migrations_table(conn)?;
103 let applied_migrations = get_applied_migrations(conn)?;
104
105 let pending_migrations: Vec<&Migration> = migrations
107 .iter()
108 .filter(|m| !applied_migrations.contains(m.id))
109 .collect();
110
111 if pending_migrations.is_empty() {
112 return Ok(());
113 }
114
115 let tx = conn.transaction()?;
117
118 for migration in pending_migrations {
119 tx.execute_batch(migration.sql)
121 .map_err(|e| Error::MigrationFailed {
122 id: migration.id.to_string(),
123 message: e.to_string(),
124 })?;
125
126 tx.execute("INSERT INTO _migrations(id) VALUES (?)", [migration.id])?;
128 }
129
130 tx.commit()?;
132
133 Ok(())
134}
135
/// Expands to the migration list generated at build time by `list`.
///
/// The expansion includes `$OUT_DIR/migrations_gen.rs`. That file holds a
/// `&[ic_sql_migrate::Migration]` slice expression (`&[]` when no migrations directory
/// exists), so the macro can be used wherever such an expression is expected.
///
/// The built-in macros are called through absolute `::core::` paths. A `macro_rules!`
/// expansion resolves macro names at the call site. There, an unqualified `include!` would
/// refer to *this* macro after `use ic_sql_migrate::include;` and fail with "no rules
/// expected this token".
#[macro_export]
macro_rules! include {
    () => {
        ::core::include!(::core::concat!(::core::env!("OUT_DIR"), "/migrations_gen.rs"))
    };
}
151
152pub fn list(migrations_dir_name: Option<&str>) -> std::io::Result<()> {
174 use std::env;
175 use std::fs;
176 use std::path::Path;
177
178 let manifest_dir = env::var("CARGO_MANIFEST_DIR").map_err(|_| {
179 std::io::Error::new(std::io::ErrorKind::NotFound, "CARGO_MANIFEST_DIR not set")
180 })?;
181
182 let dir_name = migrations_dir_name.unwrap_or("migrations");
183 let migrations_dir = Path::new(&manifest_dir).join(dir_name);
184
185 println!("cargo:rerun-if-changed={}", migrations_dir.display());
187
188 let out_dir = env::var("OUT_DIR")
190 .map_err(|_| std::io::Error::new(std::io::ErrorKind::NotFound, "OUT_DIR not set"))?;
191 let dest_path = Path::new(&out_dir).join("migrations_gen.rs");
192
193 if !migrations_dir.exists() {
195 fs::write(dest_path, "&[]")?;
196 return Ok(());
197 }
198
199 let migration_files = collect_migration_files(&migrations_dir)?;
201
202 let generated_code = generate_migrations_code(&migration_files);
204 fs::write(dest_path, generated_code)?;
205
206 Ok(())
207}
208
/// Lists the `.sql` files directly inside `migrations_dir` as `(id, path)` pairs.
///
/// The id is the file stem, and the path is the file's path rendered lossily as UTF-8.
/// Entries without a `.sql` extension, or whose stem is not valid UTF-8, are skipped. A
/// `cargo:rerun-if-changed` line is printed for every file kept. The result is sorted by id.
///
/// # Errors
///
/// Propagates any I/O error from reading the directory or its entries.
fn collect_migration_files(
    migrations_dir: &std::path::Path,
) -> std::io::Result<Vec<(String, String)>> {
    let mut found: Vec<(String, String)> = Vec::new();

    for dir_entry in std::fs::read_dir(migrations_dir)? {
        let file_path = dir_entry?.path();

        let is_sql = matches!(
            file_path.extension().and_then(|ext| ext.to_str()),
            Some("sql")
        );
        if !is_sql {
            continue;
        }

        if let Some(stem) = file_path.file_stem().and_then(|raw| raw.to_str()) {
            found.push((stem.to_owned(), file_path.to_string_lossy().into_owned()));
            println!("cargo:rerun-if-changed={}", file_path.display());
        }
    }

    // Lexicographic order on the id determines the order migrations are applied in.
    found.sort_by(|(left, _), (right, _)| left.cmp(right));
    Ok(found)
}
243
/// Renders Rust source for a `&[ic_sql_migrate::Migration]` slice expression.
///
/// Each `(id, path)` pair becomes
/// `ic_sql_migrate::Migration::new(<id>, include_str!(<path>))`, in input order.
///
/// Both strings are emitted with `{:?}`, which produces a quoted and escaped Rust string
/// literal. Without it, backslashes in Windows paths (e.g. `C:\Users\...`) become invalid
/// escape sequences, and quotes in an id end the literal early. Either way the generated
/// file fails to compile.
fn generate_migrations_code(migration_files: &[(String, String)]) -> String {
    let mut code = String::from("&[\n");

    for (migration_id, file_path) in migration_files {
        code.push_str(&format!(
            "    ic_sql_migrate::Migration::new({migration_id:?}, include_str!({file_path:?})),\n"
        ));
    }

    code.push_str("]\n");
    code
}
259
#[cfg(test)]
mod tests {
    use super::*;
    use rusqlite::Connection;

    /// Opens a fresh in-memory SQLite database for one test.
    fn memory_db() -> Connection {
        Connection::open_in_memory().unwrap()
    }

    /// Runs a query that yields a single integer (e.g. `SELECT COUNT(*) ...`).
    fn count(conn: &Connection, sql: &str) -> i64 {
        conn.query_row(sql, [], |row| row.get(0)).unwrap()
    }

    #[test]
    fn test_migration_creation() {
        let migration = Migration::new("001_test", "CREATE TABLE test (id INTEGER);");
        assert_eq!(
            (migration.id, migration.sql),
            ("001_test", "CREATE TABLE test (id INTEGER);")
        );
    }

    #[test]
    fn test_ensure_migrations_table() {
        let mut conn = memory_db();
        ensure_migrations_table(&mut conn).unwrap();

        let tables = count(
            &conn,
            "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='_migrations'",
        );
        assert_eq!(tables, 1);
    }

    #[test]
    fn test_up_migrations() {
        let mut conn = memory_db();
        let migrations = [
            Migration::new(
                "001_create_users",
                "CREATE TABLE users (id INTEGER PRIMARY KEY);",
            ),
            Migration::new("002_add_email", "ALTER TABLE users ADD COLUMN email TEXT;"),
        ];

        up(&mut conn, &migrations).unwrap();

        let applied = get_applied_migrations(&conn).unwrap();
        for id in ["001_create_users", "002_add_email"] {
            assert!(applied.contains(id));
        }

        let email_columns = count(
            &conn,
            "SELECT COUNT(*) FROM pragma_table_info('users') WHERE name='email'",
        );
        assert_eq!(email_columns, 1);
    }

    #[test]
    fn test_up_migrations_idempotency() {
        let mut conn = memory_db();
        let migrations = [Migration::new(
            "001_test",
            "CREATE TABLE test (id INTEGER);",
        )];

        // Running twice must neither fail nor record the migration twice.
        for _ in 0..2 {
            up(&mut conn, &migrations).unwrap();
        }

        let recorded = count(&conn, "SELECT COUNT(*) FROM _migrations WHERE id='001_test'");
        assert_eq!(recorded, 1);
    }

    #[test]
    fn test_migration_failure_rollback() {
        let mut conn = memory_db();
        let migrations = [
            Migration::new("001_valid", "CREATE TABLE test (id INTEGER);"),
            Migration::new("002_invalid", "INVALID SQL STATEMENT;"),
        ];

        assert!(up(&mut conn, &migrations).is_err());

        // The failing batch rolls back the whole transaction, including the valid migration.
        assert!(get_applied_migrations(&conn).unwrap().is_empty());
        let test_tables = count(
            &conn,
            "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='test'",
        );
        assert_eq!(test_tables, 0);
    }
}