ic_sql_migrate/
lib.rs

1//! A lightweight SQLite migration library for Internet Computer (ICP) canisters.
2//!
3//! This library provides automatic database schema management and version control
4//! through SQL migration files that are embedded at compile time and executed at runtime.
5
6use rusqlite::Connection;
7use std::collections::HashSet;
8use thiserror::Error;
9
/// Errors that can occur while applying migrations.
///
/// The `Database` and `Io` variants are declared with `#[from]`, so the `?`
/// operator converts the underlying `rusqlite::Error` / `std::io::Error`
/// into this type automatically.
#[derive(Debug, Error)]
pub enum Error {
    /// A SQLite operation failed; wraps the underlying `rusqlite::Error`.
    #[error("Database error: {0}")]
    Database(#[from] rusqlite::Error),

    /// An I/O operation failed; wraps the underlying `std::io::Error`.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),

    /// The SQL of a specific migration could not be executed.
    ///
    /// `id` is the identifier of the failing migration and `message` is the
    /// stringified database error.
    #[error("Migration '{id}' failed: {message}")]
    MigrationFailed { id: String, message: String },

    /// A required environment variable was not set; carries the variable name.
    // NOTE(review): nothing in the code visible here constructs this variant —
    // confirm it is still needed.
    #[error("Environment variable '{0}' not set")]
    EnvVarNotFound(String),
}
29
/// Convenience alias for results whose error type is this crate's [`Error`].
pub type MigrateResult<T> = std::result::Result<T, Error>;
31
/// One schema migration: an identifier paired with the SQL that implements it.
#[derive(Debug, Clone)]
pub struct Migration {
    /// Identifier under which the migration is recorded once it has been applied
    /// (the file stem when the migration comes from a `.sql` file)
    pub id: &'static str,
    /// SQL text that is executed when the migration is applied
    pub sql: &'static str,
}

impl Migration {
    /// Builds a migration from its identifier and SQL text.
    ///
    /// This is a `const fn`, so it can be used to initialise `static` and
    /// `const` migration tables.
    ///
    /// # Arguments
    /// * `id` - Unique identifier for the migration
    /// * `sql` - SQL statements to execute
    pub const fn new(id: &'static str, sql: &'static str) -> Self {
        Migration { id, sql }
    }
}
51
52/// Ensures the migrations tracking table exists in the database.
53///
54/// Creates a `_migrations` table if it doesn't exist, which tracks:
55/// - `id`: The unique identifier of each applied migration
56/// - `applied_at`: Timestamp when the migration was applied
57fn ensure_migrations_table(conn: &mut Connection) -> MigrateResult<()> {
58    conn.execute(
59        "CREATE TABLE IF NOT EXISTS _migrations (
60            id TEXT PRIMARY KEY,
61            applied_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP
62        )",
63        [],
64    )?;
65    Ok(())
66}
67
68/// Retrieves the set of already applied migration IDs from the database.
69fn get_applied_migrations(conn: &Connection) -> MigrateResult<HashSet<String>> {
70    let mut statement = conn.prepare("SELECT id FROM _migrations")?;
71
72    let migration_ids = statement.query_map([], |row| row.get::<_, String>(0))?;
73
74    let mut applied_set = HashSet::new();
75    for id in migration_ids.into_iter().flatten() {
76        applied_set.insert(id);
77    }
78
79    Ok(applied_set)
80}
81
82/// Executes all pending migrations in order.
83///
84/// This function:
85/// 1. Ensures the migrations tracking table exists
86/// 2. Identifies which migrations have already been applied
87/// 3. Executes pending migrations in the order they appear in the slice
88/// 4. Records each migration as applied
89///
90/// All migrations are executed within a single transaction for atomicity.
91///
92/// # Arguments
93/// * `conn` - Mutable reference to the SQLite connection
94/// * `migrations` - Slice of migrations to apply in order
95///
96/// # Errors
97/// Returns an error if:
98/// - Database operations fail
99/// - Migration SQL is invalid
100/// - Transaction cannot be committed
101pub fn up(conn: &mut Connection, migrations: &[Migration]) -> MigrateResult<()> {
102    ensure_migrations_table(conn)?;
103    let applied_migrations = get_applied_migrations(conn)?;
104
105    // Check if there are any migrations to apply
106    let pending_migrations: Vec<&Migration> = migrations
107        .iter()
108        .filter(|m| !applied_migrations.contains(m.id))
109        .collect();
110
111    if pending_migrations.is_empty() {
112        return Ok(());
113    }
114
115    // Start transaction for all migrations
116    let tx = conn.transaction()?;
117
118    for migration in pending_migrations {
119        // Execute the migration SQL
120        tx.execute_batch(migration.sql)
121            .map_err(|e| Error::MigrationFailed {
122                id: migration.id.to_string(),
123                message: e.to_string(),
124            })?;
125
126        // Record migration as applied
127        tx.execute("INSERT INTO _migrations(id) VALUES (?)", [migration.id])?;
128    }
129
130    // Commit all migrations atomically
131    tx.commit()?;
132
133    Ok(())
134}
135
/// Expands to the migration list that the `list` function generated in the
/// build script.
///
/// The expansion includes `$OUT_DIR/migrations_gen.rs`, which contains a
/// reference to an array of `Migration` values, one per discovered SQL file.
/// `OUT_DIR` is read at compile time in the crate that invokes the macro.
///
/// # Example
/// ```ignore
/// static MIGRATIONS: &[ic_sql_migrate::Migration] = ic_sql_migrate::include!();
/// ```
#[macro_export]
macro_rules! include {
    () => {
        // Fully qualified paths: this macro shares its name with the built-in
        // `include!`, so an unqualified call would resolve back to this macro
        // whenever the caller has imported it by name.
        ::core::include!(::core::concat!(
            ::core::env!("OUT_DIR"),
            "/migrations_gen.rs"
        ))
    };
}
151
152/// Discovers and lists all SQL migration files for inclusion at compile time.
153///
154/// This function should be called in `build.rs` to generate code that embeds
155/// all migration files into the binary. It scans the specified directory for
156/// `.sql` files and generates Rust code to include them.
157///
158/// # Arguments
159/// * `migrations_dir_name` - Optional custom directory name (defaults to "migrations")
160///
161/// # Example
162/// ```no_run
163/// // In build.rs
164/// fn main() {
165///     migrations::list(Some("migrations")).unwrap();
166/// }
167/// ```
168///
169/// # Errors
170/// Returns an I/O error if:
171/// - The output directory cannot be written to
172/// - File system operations fail
173pub fn list(migrations_dir_name: Option<&str>) -> std::io::Result<()> {
174    use std::env;
175    use std::fs;
176    use std::path::Path;
177
178    let manifest_dir = env::var("CARGO_MANIFEST_DIR").map_err(|_| {
179        std::io::Error::new(std::io::ErrorKind::NotFound, "CARGO_MANIFEST_DIR not set")
180    })?;
181
182    let dir_name = migrations_dir_name.unwrap_or("migrations");
183    let migrations_dir = Path::new(&manifest_dir).join(dir_name);
184
185    // Ensure cargo rebuilds when migrations change
186    println!("cargo:rerun-if-changed={}", migrations_dir.display());
187
188    // Generate the output file path
189    let out_dir = env::var("OUT_DIR")
190        .map_err(|_| std::io::Error::new(std::io::ErrorKind::NotFound, "OUT_DIR not set"))?;
191    let dest_path = Path::new(&out_dir).join("migrations_gen.rs");
192
193    // If migrations directory doesn't exist, create empty migrations array
194    if !migrations_dir.exists() {
195        fs::write(dest_path, "&[]")?;
196        return Ok(());
197    }
198
199    // Collect all SQL files
200    let migration_files = collect_migration_files(&migrations_dir)?;
201
202    // Generate and write the Rust code
203    let generated_code = generate_migrations_code(&migration_files);
204    fs::write(dest_path, generated_code)?;
205
206    Ok(())
207}
208
/// Collects all SQL migration files from the specified directory.
///
/// Only regular files with the extension `sql` are considered; sub-directories
/// are skipped even if their name ends in `.sql`, because the generated
/// `include_str!` could not read them. Entries whose file stem is not valid
/// UTF-8 are skipped as well. For every collected file a
/// `cargo:rerun-if-changed` line is printed.
///
/// Returns `(migration_id, file_path)` tuples sorted by migration id, where
/// the id is the file stem.
///
/// # Errors
/// Returns an I/O error if the directory or one of its entries cannot be read.
fn collect_migration_files(
    migrations_dir: &std::path::Path,
) -> std::io::Result<Vec<(String, String)>> {
    use std::fs;

    let mut migration_files = Vec::new();

    for entry in fs::read_dir(migrations_dir)? {
        let path = entry?.path();

        // Only process regular .sql files
        let has_sql_extension = path.extension().and_then(|ext| ext.to_str()) == Some("sql");
        if !has_sql_extension || !path.is_file() {
            continue;
        }

        if let Some(file_stem) = path.file_stem().and_then(|stem| stem.to_str()) {
            // Ensure cargo rebuilds when this specific file changes
            println!("cargo:rerun-if-changed={}", path.display());

            migration_files.push((
                file_stem.to_string(),
                path.to_string_lossy().into_owned(),
            ));
        }
    }

    // Sort migration files by name to ensure consistent ordering
    migration_files.sort_by(|a, b| a.0.cmp(&b.0));

    Ok(migration_files)
}
243
/// Generates the Rust source that embeds the given migration files.
///
/// The result is a reference to an array expression (`&[ ... ]`) containing
/// one `ic_sql_migrate::Migration::new(id, include_str!(path))` entry per
/// input tuple, in input order. An empty input yields an empty array.
///
/// Ids and paths are written as escaped Rust string literals, so paths that
/// contain backslashes (e.g. on Windows) or ids that contain quotes still
/// produce code that compiles.
///
/// # Arguments
/// * `migration_files` - `(migration_id, file_path)` tuples
fn generate_migrations_code(migration_files: &[(String, String)]) -> String {
    let mut code = String::from("&[\n");

    for (migration_id, file_path) in migration_files {
        // `{:?}` on a string emits a quoted, escaped literal.
        code.push_str(&format!(
            "    ic_sql_migrate::Migration::new({migration_id:?}, include_str!({file_path:?})),\n"
        ));
    }

    code.push_str("]\n");
    code
}
259
#[cfg(test)]
mod tests {
    use super::*;
    use rusqlite::Connection;

    /// Opens a fresh in-memory database for a test.
    fn memory_db() -> Connection {
        Connection::open_in_memory().unwrap()
    }

    /// Runs a query that yields a single integer (e.g. `SELECT COUNT(*) ...`).
    fn count(conn: &Connection, sql: &str) -> i64 {
        conn.query_row(sql, [], |row| row.get(0)).unwrap()
    }

    #[test]
    fn test_migration_creation() {
        let migration = Migration::new("001_test", "CREATE TABLE test (id INTEGER);");

        assert_eq!(migration.id, "001_test");
        assert_eq!(migration.sql, "CREATE TABLE test (id INTEGER);");
    }

    #[test]
    fn test_ensure_migrations_table() {
        let mut conn = memory_db();
        ensure_migrations_table(&mut conn).unwrap();

        // The tracking table must now be listed in the schema.
        let tables = count(
            &conn,
            "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='_migrations'",
        );
        assert_eq!(tables, 1);
    }

    #[test]
    fn test_up_migrations() {
        let mut conn = memory_db();
        let migrations = &[
            Migration::new(
                "001_create_users",
                "CREATE TABLE users (id INTEGER PRIMARY KEY);",
            ),
            Migration::new("002_add_email", "ALTER TABLE users ADD COLUMN email TEXT;"),
        ];

        up(&mut conn, migrations).unwrap();

        // Both migrations are recorded as applied.
        let applied = get_applied_migrations(&conn).unwrap();
        assert!(applied.contains("001_create_users"));
        assert!(applied.contains("002_add_email"));

        // The second migration's column exists on the table.
        let email_columns = count(
            &conn,
            "SELECT COUNT(*) FROM pragma_table_info('users') WHERE name='email'",
        );
        assert_eq!(email_columns, 1);
    }

    #[test]
    fn test_up_migrations_idempotency() {
        let mut conn = memory_db();
        let migrations = &[Migration::new(
            "001_test",
            "CREATE TABLE test (id INTEGER);",
        )];

        // A second run must be a no-op.
        for _ in 0..2 {
            up(&mut conn, migrations).unwrap();
        }

        let recorded = count(
            &conn,
            "SELECT COUNT(*) FROM _migrations WHERE id='001_test'",
        );
        assert_eq!(recorded, 1);
    }

    #[test]
    fn test_migration_failure_rollback() {
        let mut conn = memory_db();
        let migrations = &[
            Migration::new("001_valid", "CREATE TABLE test (id INTEGER);"),
            Migration::new("002_invalid", "INVALID SQL STATEMENT;"),
        ];

        // The second migration is invalid, so the whole run must fail.
        assert!(up(&mut conn, migrations).is_err());

        // The first migration was rolled back together with the failing one.
        assert!(get_applied_migrations(&conn).unwrap().is_empty());

        // Its table must not exist either.
        let tables = count(
            &conn,
            "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='test'",
        );
        assert_eq!(tables, 0);
    }
}
370}