Skip to main content

duroxide_pg/
migrations.rs

1use anyhow::Result;
2use include_dir::{include_dir, Dir};
3use sqlx::Connection;
4use sqlx::PgPool;
5use std::sync::Arc;
6
/// SQL migration files from the crate's `migrations/` directory, embedded
/// into the binary by `include_dir!`.
static MIGRATIONS: Dir = include_dir!("$CARGO_MANIFEST_DIR/migrations");
8
/// Migration metadata
///
/// One embedded `.sql` file as loaded by `MigrationRunner::load_migrations`.
#[derive(Debug)]
struct Migration {
    /// Numeric version parsed from the filename prefix
    /// (e.g. `0001_initial.sql` -> 1).
    version: i64,
    /// File name of the migration, including its extension.
    name: String,
    /// Full UTF-8 text of the migration file.
    sql: String,
}
16
/// Migration runner that handles schema-qualified migrations
pub struct MigrationRunner {
    /// Shared connection pool; one connection is acquired per
    /// `migrate()` / `verify()` call.
    pool: Arc<PgPool>,
    /// Target schema. It is interpolated directly (unquoted) into the SQL
    /// text built by this type.
    // NOTE(review): presumably validated as a plain identifier by the
    // caller — confirm, since it is not escaped here.
    schema_name: String,
}
22
23impl MigrationRunner {
24    /// Create a new migration runner
25    pub fn new(pool: Arc<PgPool>, schema_name: String) -> Self {
26        Self { pool, schema_name }
27    }
28
29    fn advisory_lock_key(&self) -> i64 {
30        // Stable 64-bit FNV-1a hash over (namespace + schema name).
31        // This avoids using Rust's DefaultHasher (randomized per-process).
32        const OFFSET: u64 = 0xcbf29ce484222325;
33        const PRIME: u64 = 0x100000001b3;
34
35        let mut hash = OFFSET;
36        for b in b"duroxide_pg:migrations:" {
37            hash ^= *b as u64;
38            hash = hash.wrapping_mul(PRIME);
39        }
40        for b in self.schema_name.as_bytes() {
41            hash ^= *b as u64;
42            hash = hash.wrapping_mul(PRIME);
43        }
44
45        hash as i64
46    }
47
48    async fn lock_for_migrations(&self, conn: &mut sqlx::postgres::PgConnection) -> Result<()> {
49        let key = self.advisory_lock_key();
50        // Session lock (not xact lock) so it spans multiple transactions.
51        // We explicitly unlock at the end.
52        sqlx::query("SELECT pg_advisory_lock($1)")
53            .bind(key)
54            .execute(&mut *conn)
55            .await?;
56        Ok(())
57    }
58
59    async fn unlock_for_migrations(&self, conn: &mut sqlx::postgres::PgConnection) {
60        let key = self.advisory_lock_key();
61        // Best-effort unlock.
62        let _ = sqlx::query("SELECT pg_advisory_unlock($1)")
63            .bind(key)
64            .execute(&mut *conn)
65            .await;
66    }
67
68    /// Run all pending migrations
69    pub async fn migrate(&self) -> Result<()> {
70        let mut conn = self.pool.acquire().await?;
71        let conn = &mut *conn;
72        self.lock_for_migrations(conn).await?;
73
74        // Reject unknown migrations while holding the advisory lock so that an
75        // older binary cannot rewrite a schema that is ahead of its code.
76        // Short-circuit: do NOT run migrate_inner if unknown migrations are
77        // detected.
78        let result = match self.check_no_unknown_migrations(conn).await {
79            Ok(()) => self.migrate_inner(conn).await,
80            Err(e) => Err(e),
81        };
82        self.unlock_for_migrations(conn).await;
83
84        result
85    }
86
87    /// Verify that the migration tracking table exists and that every embedded
88    /// migration has already been applied. Does not take the migration
89    /// advisory lock and does not create or modify any database objects.
90    ///
91    /// Returns an error if the `_duroxide_migrations` table is missing in
92    /// `schema_name` or if any bundled migration version is absent from it.
93    ///
94    /// Intended for processes that must not perform DDL (e.g. application
95    /// backends, where a separately privileged worker is responsible for
96    /// applying schema changes).
97    pub async fn verify(&self) -> Result<()> {
98        let mut conn = self.pool.acquire().await?;
99        let conn = &mut *conn;
100
101        // Check that the tracking table exists in the target schema.
102        let table_exists: bool = sqlx::query_scalar(
103            "SELECT EXISTS(SELECT 1 FROM information_schema.tables \
104             WHERE table_schema = $1 AND table_name = '_duroxide_migrations')",
105        )
106        .bind(&self.schema_name)
107        .fetch_one(&mut *conn)
108        .await?;
109
110        if !table_exists {
111            anyhow::bail!(
112                "duroxide migrations not initialized: schema {:?} does not \
113                 contain _duroxide_migrations. Construct a provider with \
114                 MigrationPolicy::ApplyAll (the default) from a process with \
115                 DDL privileges before using MigrationPolicy::VerifyOnly.",
116                self.schema_name
117            );
118        }
119
120        // Reject schemas that have migrations the running binary does not
121        // recognize (schema is ahead of the code).
122        self.check_no_unknown_migrations(conn).await?;
123
124        let migrations = self.load_migrations()?;
125        let applied: std::collections::HashSet<i64> =
126            self.get_applied_versions(conn).await?.into_iter().collect();
127
128        let mut missing: Vec<i64> = migrations
129            .iter()
130            .map(|m| m.version)
131            .filter(|v| !applied.contains(v))
132            .collect();
133        missing.sort_unstable();
134
135        if !missing.is_empty() {
136            anyhow::bail!(
137                "duroxide migrations not up to date in schema {:?}: missing \
138                 versions {:?}. Run migrations from a provider configured with \
139                 MigrationPolicy::ApplyAll before constructing VerifyOnly \
140                 providers.",
141                self.schema_name,
142                missing,
143            );
144        }
145
146        if !self.check_tables_exist(conn).await.unwrap_or(false) {
147            anyhow::bail!(
148                "duroxide migrations recorded as complete in schema {:?}, but \
149                 core tables are missing. The schema may be corrupted; run \
150                 migrations from a provider configured with \
151                 MigrationPolicy::ApplyAll before constructing VerifyOnly \
152                 providers.",
153                self.schema_name,
154            );
155        }
156
157        Ok(())
158    }
159
160    /// Check that the database has no migrations the running binary does not
161    /// recognize. Used by both `migrate()` (to refuse running DDL against a
162    /// schema ahead of the code) and `verify()` (to refuse claiming
163    /// successful verification of an unknown schema).
164    ///
165    /// Returns `Ok(())` if the tracking table does not yet exist: under
166    /// `ApplyAll` it will be created by `migrate_inner`, and under
167    /// `VerifyOnly` the missing table is reported separately before this is
168    /// called.
169    async fn check_no_unknown_migrations(
170        &self,
171        conn: &mut sqlx::postgres::PgConnection,
172    ) -> Result<()> {
173        let tracking_exists: bool = sqlx::query_scalar(
174            "SELECT EXISTS(SELECT 1 FROM information_schema.tables \
175             WHERE table_schema = $1 AND table_name = '_duroxide_migrations')",
176        )
177        .bind(&self.schema_name)
178        .fetch_one(&mut *conn)
179        .await?;
180
181        if !tracking_exists {
182            return Ok(());
183        }
184
185        let applied = self.get_applied_versions(conn).await?;
186        let expected: std::collections::HashSet<i64> = self
187            .load_migrations()?
188            .into_iter()
189            .map(|m| m.version)
190            .collect();
191
192        let mut unknown: Vec<i64> = applied
193            .into_iter()
194            .filter(|v| !expected.contains(v))
195            .collect();
196        unknown.sort_unstable();
197
198        if !unknown.is_empty() {
199            anyhow::bail!(
200                "schema {:?} has migrations not recognized by this version of \
201                 the code: {:?}. The database schema is ahead of the code. \
202                 Update the code to a compatible version.",
203                self.schema_name,
204                unknown,
205            );
206        }
207
208        Ok(())
209    }
210
211    async fn migrate_inner(&self, conn: &mut sqlx::postgres::PgConnection) -> Result<()> {
212        // Ensure schema exists
213        if self.schema_name != "public" {
214            sqlx::query(&format!("CREATE SCHEMA IF NOT EXISTS {}", self.schema_name))
215                .execute(&mut *conn)
216                .await?;
217        }
218
219        // Load migrations from filesystem
220        let migrations = self.load_migrations()?;
221
222        tracing::debug!(
223            "Loaded {} migrations for schema {}",
224            migrations.len(),
225            self.schema_name
226        );
227
228        // Ensure migration tracking table exists (in the schema)
229        self.ensure_migration_table(conn).await?;
230
231        // Get applied migrations
232        let applied_versions = self.get_applied_versions(conn).await?;
233
234        tracing::debug!("Applied migrations: {:?}", applied_versions);
235
236        // Check if key tables exist - if not, we need to re-run migrations even if marked as applied
237        // This handles the case where cleanup dropped tables but not the migration tracking table
238        let tables_exist = self.check_tables_exist(conn).await.unwrap_or(false);
239
240        // Apply pending migrations (or re-apply if tables don't exist)
241        for migration in migrations {
242            let should_apply = if !applied_versions.contains(&migration.version) {
243                true // New migration
244            } else if !tables_exist {
245                // Migration was applied but tables don't exist - re-apply
246                tracing::warn!(
247                    "Migration {} is marked as applied but tables don't exist, re-applying",
248                    migration.version
249                );
250                // Remove the old migration record so we can re-apply
251                sqlx::query(&format!(
252                    "DELETE FROM {}._duroxide_migrations WHERE version = $1",
253                    self.schema_name
254                ))
255                .bind(migration.version)
256                .execute(&mut *conn)
257                .await?;
258                true
259            } else {
260                false // Already applied and tables exist
261            };
262
263            if should_apply {
264                tracing::debug!(
265                    "Applying migration {}: {}",
266                    migration.version,
267                    migration.name
268                );
269                self.apply_migration(conn, &migration).await?;
270            } else {
271                tracing::debug!(
272                    "Skipping migration {}: {} (already applied)",
273                    migration.version,
274                    migration.name
275                );
276            }
277        }
278
279        Ok(())
280    }
281
282    /// Load migrations from the embedded migrations directory
283    fn load_migrations(&self) -> Result<Vec<Migration>> {
284        let mut migrations = Vec::new();
285
286        // Get all files from embedded directory
287        let mut files: Vec<_> = MIGRATIONS
288            .files()
289            .filter(|file| file.path().extension().and_then(|ext| ext.to_str()) == Some("sql"))
290            .collect();
291
292        // Sort by path to ensure consistent ordering
293        files.sort_by_key(|f| f.path());
294
295        for file in files {
296            let file_name = file
297                .path()
298                .file_name()
299                .and_then(|n| n.to_str())
300                .ok_or_else(|| anyhow::anyhow!("Invalid filename in migrations"))?;
301
302            let sql = file
303                .contents_utf8()
304                .ok_or_else(|| anyhow::anyhow!("Migration file is not valid UTF-8: {file_name}"))?
305                .to_string();
306
307            let version = self.parse_version(file_name)?;
308            let name = file_name.to_string();
309
310            migrations.push(Migration { version, name, sql });
311        }
312
313        Ok(migrations)
314    }
315
316    /// Parse version number from migration filename (e.g., "0001_initial.sql" -> 1)
317    fn parse_version(&self, filename: &str) -> Result<i64> {
318        let version_str = filename
319            .split('_')
320            .next()
321            .ok_or_else(|| anyhow::anyhow!("Invalid migration filename: {filename}"))?;
322
323        version_str
324            .parse::<i64>()
325            .map_err(|e| anyhow::anyhow!("Invalid migration version {version_str}: {e}"))
326    }
327
328    /// Ensure migration tracking table exists
329    async fn ensure_migration_table(&self, conn: &mut sqlx::postgres::PgConnection) -> Result<()> {
330        // Create migration table in the target schema
331        sqlx::query(&format!(
332            r#"
333            CREATE TABLE IF NOT EXISTS {}._duroxide_migrations (
334                version BIGINT PRIMARY KEY,
335                name TEXT NOT NULL,
336                applied_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP
337            )
338            "#,
339            self.schema_name
340        ))
341        .execute(&mut *conn)
342        .await?;
343
344        Ok(())
345    }
346
347    /// Check if key tables exist
348    async fn check_tables_exist(&self, conn: &mut sqlx::postgres::PgConnection) -> Result<bool> {
349        // Check if instances table exists (as a proxy for all tables)
350        let exists: bool = sqlx::query_scalar(
351            "SELECT EXISTS(SELECT 1 FROM information_schema.tables WHERE table_schema = $1 AND table_name = 'instances')",
352        )
353        .bind(&self.schema_name)
354        .fetch_one(&mut *conn)
355        .await?;
356
357        Ok(exists)
358    }
359
360    /// Get list of applied migration versions
361    async fn get_applied_versions(
362        &self,
363        conn: &mut sqlx::postgres::PgConnection,
364    ) -> Result<Vec<i64>> {
365        let versions: Vec<i64> = sqlx::query_scalar(&format!(
366            "SELECT version FROM {}._duroxide_migrations ORDER BY version",
367            self.schema_name
368        ))
369        .fetch_all(&mut *conn)
370        .await?;
371
372        Ok(versions)
373    }
374
375    /// Split SQL into statements, respecting dollar-quoted strings ($$...$$)
376    /// This handles stored procedures and other constructs that use dollar-quoting
377    fn split_sql_statements(sql: &str) -> Vec<String> {
378        let mut statements = Vec::new();
379        let mut current_statement = String::new();
380        let chars: Vec<char> = sql.chars().collect();
381        let mut i = 0;
382        let mut in_dollar_quote = false;
383        let mut dollar_tag: Option<String> = None;
384
385        while i < chars.len() {
386            let ch = chars[i];
387
388            if !in_dollar_quote {
389                // Check for start of dollar-quoted string
390                if ch == '$' {
391                    let mut tag = String::new();
392                    tag.push(ch);
393                    i += 1;
394
395                    // Collect the tag (e.g., $$, $tag$, $function$)
396                    while i < chars.len() {
397                        let next_ch = chars[i];
398                        if next_ch == '$' {
399                            tag.push(next_ch);
400                            dollar_tag = Some(tag.clone());
401                            in_dollar_quote = true;
402                            current_statement.push_str(&tag);
403                            i += 1;
404                            break;
405                        } else if next_ch.is_alphanumeric() || next_ch == '_' {
406                            tag.push(next_ch);
407                            i += 1;
408                        } else {
409                            // Not a dollar quote, just a $ character
410                            current_statement.push(ch);
411                            break;
412                        }
413                    }
414                } else if ch == ';' {
415                    // End of statement (only if not in dollar quote)
416                    current_statement.push(ch);
417                    let trimmed = current_statement.trim().to_string();
418                    if !trimmed.is_empty() {
419                        statements.push(trimmed);
420                    }
421                    current_statement.clear();
422                    i += 1;
423                } else {
424                    current_statement.push(ch);
425                    i += 1;
426                }
427            } else {
428                // Inside dollar-quoted string
429                current_statement.push(ch);
430
431                // Check for end of dollar-quoted string
432                if ch == '$' {
433                    let tag = dollar_tag.as_ref().unwrap();
434                    let mut matches = true;
435
436                    // Check if the following characters match the closing tag
437                    for (j, tag_char) in tag.chars().enumerate() {
438                        if j == 0 {
439                            continue; // Skip first $ (we already matched it)
440                        }
441                        if i + j >= chars.len() || chars[i + j] != tag_char {
442                            matches = false;
443                            break;
444                        }
445                    }
446
447                    if matches {
448                        // Found closing tag - consume remaining tag characters
449                        for _ in 0..(tag.len() - 1) {
450                            if i + 1 < chars.len() {
451                                current_statement.push(chars[i + 1]);
452                                i += 1;
453                            }
454                        }
455                        in_dollar_quote = false;
456                        dollar_tag = None;
457                    }
458                }
459                i += 1;
460            }
461        }
462
463        // Add remaining statement if any
464        let trimmed = current_statement.trim().to_string();
465        if !trimmed.is_empty() {
466            statements.push(trimmed);
467        }
468
469        statements
470    }
471
472    /// Apply a single migration
473    async fn apply_migration(
474        &self,
475        conn: &mut sqlx::postgres::PgConnection,
476        migration: &Migration,
477    ) -> Result<()> {
478        // Start transaction
479        let mut tx = conn.begin().await?;
480
481        // Set search_path for this transaction
482        sqlx::query(&format!("SET LOCAL search_path TO {}", self.schema_name))
483            .execute(&mut *tx)
484            .await?;
485
486        // Remove comment lines and split SQL into individual statements
487        let sql = migration.sql.trim();
488        let cleaned_sql: String = sql
489            .lines()
490            .map(|line| {
491                // Remove full-line comments
492                if let Some(idx) = line.find("--") {
493                    // Check if -- is inside a string (simple check)
494                    let before = &line[..idx];
495                    if before.matches('\'').count() % 2 == 0 {
496                        // Even number of quotes means -- is not in a string
497                        line[..idx].trim()
498                    } else {
499                        line
500                    }
501                } else {
502                    line
503                }
504            })
505            .filter(|line| !line.is_empty())
506            .collect::<Vec<_>>()
507            .join("\n");
508
509        // Split by semicolon, but respect dollar-quoted strings ($$...$$)
510        let statements = Self::split_sql_statements(&cleaned_sql);
511
512        tracing::debug!(
513            "Executing {} statements for migration {}",
514            statements.len(),
515            migration.version
516        );
517
518        for (idx, statement) in statements.iter().enumerate() {
519            if !statement.trim().is_empty() {
520                tracing::debug!(
521                    "Executing statement {} of {}: {}...",
522                    idx + 1,
523                    statements.len(),
524                    &statement.chars().take(50).collect::<String>()
525                );
526                sqlx::query(statement)
527                    .execute(&mut *tx)
528                    .await
529                    .map_err(|e| {
530                        anyhow::anyhow!(
531                            "Failed to execute statement {} in migration {}: {}\nStatement: {}",
532                            idx + 1,
533                            migration.version,
534                            e,
535                            statement
536                        )
537                    })?;
538            }
539        }
540
541        // Record migration as applied
542        sqlx::query(&format!(
543            "INSERT INTO {}._duroxide_migrations (version, name) VALUES ($1, $2)",
544            self.schema_name
545        ))
546        .bind(migration.version)
547        .bind(&migration.name)
548        .execute(&mut *tx)
549        .await?;
550
551        // Commit transaction
552        tx.commit().await?;
553
554        tracing::info!(
555            "Applied migration {}: {}",
556            migration.version,
557            migration.name
558        );
559
560        Ok(())
561    }
562}