//! `duroxide_pg` — `migrations.rs`: runs the embedded, schema-qualified SQL migrations.

use anyhow::Result;
use include_dir::{include_dir, Dir};
use sqlx::postgres::PgConnection;
use sqlx::Connection;
use sqlx::PgPool;
use std::collections::HashSet;
use std::sync::Arc;

7static MIGRATIONS: Dir = include_dir!("$CARGO_MANIFEST_DIR/migrations");
8
/// A single migration script loaded from the embedded `migrations/` directory.
#[derive(Debug)]
struct Migration {
    /// Version parsed from the file-name prefix before the first `_`
    /// (e.g. `0001_initial.sql` -> 1); also the primary key in `_duroxide_migrations`.
    version: i64,
    /// File name of the script; stored in `_duroxide_migrations.name` when applied.
    name: String,
    /// Full UTF-8 contents of the script.
    sql: String,
}

17/// Migration runner that handles schema-qualified migrations
18pub struct MigrationRunner {
19    pool: Arc<PgPool>,
20    schema_name: String,
21}
22
23impl MigrationRunner {
24    /// Create a new migration runner
25    pub fn new(pool: Arc<PgPool>, schema_name: String) -> Self {
26        Self { pool, schema_name }
27    }
28
29    fn advisory_lock_key(&self) -> i64 {
30        // Stable 64-bit FNV-1a hash over (namespace + schema name).
31        // This avoids using Rust's DefaultHasher (randomized per-process).
32        const OFFSET: u64 = 0xcbf29ce484222325;
33        const PRIME: u64 = 0x100000001b3;
34
35        let mut hash = OFFSET;
36        for b in b"duroxide_pg:migrations:" {
37            hash ^= *b as u64;
38            hash = hash.wrapping_mul(PRIME);
39        }
40        for b in self.schema_name.as_bytes() {
41            hash ^= *b as u64;
42            hash = hash.wrapping_mul(PRIME);
43        }
44
45        hash as i64
46    }
47
48    async fn lock_for_migrations(&self, conn: &mut sqlx::postgres::PgConnection) -> Result<()> {
49        let key = self.advisory_lock_key();
50        // Session lock (not xact lock) so it spans multiple transactions.
51        // We explicitly unlock at the end.
52        sqlx::query("SELECT pg_advisory_lock($1)")
53            .bind(key)
54            .execute(&mut *conn)
55            .await?;
56        Ok(())
57    }
58
59    async fn unlock_for_migrations(&self, conn: &mut sqlx::postgres::PgConnection) {
60        let key = self.advisory_lock_key();
61        // Best-effort unlock.
62        let _ = sqlx::query("SELECT pg_advisory_unlock($1)")
63            .bind(key)
64            .execute(&mut *conn)
65            .await;
66    }
67
68    /// Run all pending migrations
69    pub async fn migrate(&self) -> Result<()> {
70        let mut conn = self.pool.acquire().await?;
71        let conn = &mut *conn;
72        self.lock_for_migrations(conn).await?;
73
74        let result = self.migrate_inner(conn).await;
75        self.unlock_for_migrations(conn).await;
76
77        result
78    }
79
80    async fn migrate_inner(
81        &self,
82        conn: &mut sqlx::postgres::PgConnection,
83    ) -> Result<()> {
84        // Ensure schema exists
85        if self.schema_name != "public" {
86            sqlx::query(&format!("CREATE SCHEMA IF NOT EXISTS {}", self.schema_name))
87                .execute(&mut *conn)
88                .await?;
89        }
90
91        // Load migrations from filesystem
92        let migrations = self.load_migrations()?;
93
94        tracing::debug!(
95            "Loaded {} migrations for schema {}",
96            migrations.len(),
97            self.schema_name
98        );
99
100        // Ensure migration tracking table exists (in the schema)
101        self.ensure_migration_table(conn).await?;
102
103        // Get applied migrations
104        let applied_versions = self.get_applied_versions(conn).await?;
105
106        tracing::debug!("Applied migrations: {:?}", applied_versions);
107
108        // Check if key tables exist - if not, we need to re-run migrations even if marked as applied
109        // This handles the case where cleanup dropped tables but not the migration tracking table
110        let tables_exist = self.check_tables_exist(conn).await.unwrap_or(false);
111
112        // Apply pending migrations (or re-apply if tables don't exist)
113        for migration in migrations {
114            let should_apply = if !applied_versions.contains(&migration.version) {
115                true // New migration
116            } else if !tables_exist {
117                // Migration was applied but tables don't exist - re-apply
118                tracing::warn!(
119                    "Migration {} is marked as applied but tables don't exist, re-applying",
120                    migration.version
121                );
122                // Remove the old migration record so we can re-apply
123                sqlx::query(&format!(
124                    "DELETE FROM {}._duroxide_migrations WHERE version = $1",
125                    self.schema_name
126                ))
127                .bind(migration.version)
128                .execute(&mut *conn)
129                .await?;
130                true
131            } else {
132                false // Already applied and tables exist
133            };
134
135            if should_apply {
136                tracing::debug!(
137                    "Applying migration {}: {}",
138                    migration.version,
139                    migration.name
140                );
141                self.apply_migration(conn, &migration).await?;
142            } else {
143                tracing::debug!(
144                    "Skipping migration {}: {} (already applied)",
145                    migration.version,
146                    migration.name
147                );
148            }
149        }
150
151        Ok(())
152    }
153
154    /// Load migrations from the embedded migrations directory
155    fn load_migrations(&self) -> Result<Vec<Migration>> {
156        let mut migrations = Vec::new();
157
158        // Get all files from embedded directory
159        let mut files: Vec<_> = MIGRATIONS
160            .files()
161            .filter(|file| file.path().extension().and_then(|ext| ext.to_str()) == Some("sql"))
162            .collect();
163
164        // Sort by path to ensure consistent ordering
165        files.sort_by_key(|f| f.path());
166
167        for file in files {
168            let file_name = file
169                .path()
170                .file_name()
171                .and_then(|n| n.to_str())
172                .ok_or_else(|| anyhow::anyhow!("Invalid filename in migrations"))?;
173
174            let sql = file
175                .contents_utf8()
176                .ok_or_else(|| anyhow::anyhow!("Migration file is not valid UTF-8: {file_name}"))?
177                .to_string();
178
179            let version = self.parse_version(file_name)?;
180            let name = file_name.to_string();
181
182            migrations.push(Migration { version, name, sql });
183        }
184
185        Ok(migrations)
186    }
187
188    /// Parse version number from migration filename (e.g., "0001_initial.sql" -> 1)
189    fn parse_version(&self, filename: &str) -> Result<i64> {
190        let version_str = filename
191            .split('_')
192            .next()
193            .ok_or_else(|| anyhow::anyhow!("Invalid migration filename: {filename}"))?;
194
195        version_str
196            .parse::<i64>()
197            .map_err(|e| anyhow::anyhow!("Invalid migration version {version_str}: {e}"))
198    }
199
200    /// Ensure migration tracking table exists
201    async fn ensure_migration_table(&self, conn: &mut sqlx::postgres::PgConnection) -> Result<()> {
202        // Create migration table in the target schema
203        sqlx::query(&format!(
204            r#"
205            CREATE TABLE IF NOT EXISTS {}._duroxide_migrations (
206                version BIGINT PRIMARY KEY,
207                name TEXT NOT NULL,
208                applied_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP
209            )
210            "#,
211            self.schema_name
212        ))
213        .execute(&mut *conn)
214        .await?;
215
216        Ok(())
217    }
218
219    /// Check if key tables exist
220    async fn check_tables_exist(&self, conn: &mut sqlx::postgres::PgConnection) -> Result<bool> {
221        // Check if instances table exists (as a proxy for all tables)
222        let exists: bool = sqlx::query_scalar(
223            "SELECT EXISTS(SELECT 1 FROM information_schema.tables WHERE table_schema = $1 AND table_name = 'instances')",
224        )
225        .bind(&self.schema_name)
226        .fetch_one(&mut *conn)
227        .await?;
228
229        Ok(exists)
230    }
231
232    /// Get list of applied migration versions
233    async fn get_applied_versions(&self, conn: &mut sqlx::postgres::PgConnection) -> Result<Vec<i64>> {
234        let versions: Vec<i64> = sqlx::query_scalar(&format!(
235            "SELECT version FROM {}._duroxide_migrations ORDER BY version",
236            self.schema_name
237        ))
238        .fetch_all(&mut *conn)
239        .await?;
240
241        Ok(versions)
242    }
243
244    /// Split SQL into statements, respecting dollar-quoted strings ($$...$$)
245    /// This handles stored procedures and other constructs that use dollar-quoting
246    fn split_sql_statements(sql: &str) -> Vec<String> {
247        let mut statements = Vec::new();
248        let mut current_statement = String::new();
249        let chars: Vec<char> = sql.chars().collect();
250        let mut i = 0;
251        let mut in_dollar_quote = false;
252        let mut dollar_tag: Option<String> = None;
253
254        while i < chars.len() {
255            let ch = chars[i];
256
257            if !in_dollar_quote {
258                // Check for start of dollar-quoted string
259                if ch == '$' {
260                    let mut tag = String::new();
261                    tag.push(ch);
262                    i += 1;
263
264                    // Collect the tag (e.g., $$, $tag$, $function$)
265                    while i < chars.len() {
266                        let next_ch = chars[i];
267                        if next_ch == '$' {
268                            tag.push(next_ch);
269                            dollar_tag = Some(tag.clone());
270                            in_dollar_quote = true;
271                            current_statement.push_str(&tag);
272                            i += 1;
273                            break;
274                        } else if next_ch.is_alphanumeric() || next_ch == '_' {
275                            tag.push(next_ch);
276                            i += 1;
277                        } else {
278                            // Not a dollar quote, just a $ character
279                            current_statement.push(ch);
280                            break;
281                        }
282                    }
283                } else if ch == ';' {
284                    // End of statement (only if not in dollar quote)
285                    current_statement.push(ch);
286                    let trimmed = current_statement.trim().to_string();
287                    if !trimmed.is_empty() {
288                        statements.push(trimmed);
289                    }
290                    current_statement.clear();
291                    i += 1;
292                } else {
293                    current_statement.push(ch);
294                    i += 1;
295                }
296            } else {
297                // Inside dollar-quoted string
298                current_statement.push(ch);
299
300                // Check for end of dollar-quoted string
301                if ch == '$' {
302                    let tag = dollar_tag.as_ref().unwrap();
303                    let mut matches = true;
304
305                    // Check if the following characters match the closing tag
306                    for (j, tag_char) in tag.chars().enumerate() {
307                        if j == 0 {
308                            continue; // Skip first $ (we already matched it)
309                        }
310                        if i + j >= chars.len() || chars[i + j] != tag_char {
311                            matches = false;
312                            break;
313                        }
314                    }
315
316                    if matches {
317                        // Found closing tag - consume remaining tag characters
318                        for _ in 0..(tag.len() - 1) {
319                            if i + 1 < chars.len() {
320                                current_statement.push(chars[i + 1]);
321                                i += 1;
322                            }
323                        }
324                        in_dollar_quote = false;
325                        dollar_tag = None;
326                    }
327                }
328                i += 1;
329            }
330        }
331
332        // Add remaining statement if any
333        let trimmed = current_statement.trim().to_string();
334        if !trimmed.is_empty() {
335            statements.push(trimmed);
336        }
337
338        statements
339    }
340
341    /// Apply a single migration
342    async fn apply_migration(&self, conn: &mut sqlx::postgres::PgConnection, migration: &Migration) -> Result<()> {
343        // Start transaction
344        let mut tx = conn.begin().await?;
345
346        // Set search_path for this transaction
347        sqlx::query(&format!("SET LOCAL search_path TO {}", self.schema_name))
348            .execute(&mut *tx)
349            .await?;
350
351        // Remove comment lines and split SQL into individual statements
352        let sql = migration.sql.trim();
353        let cleaned_sql: String = sql
354            .lines()
355            .map(|line| {
356                // Remove full-line comments
357                if let Some(idx) = line.find("--") {
358                    // Check if -- is inside a string (simple check)
359                    let before = &line[..idx];
360                    if before.matches('\'').count() % 2 == 0 {
361                        // Even number of quotes means -- is not in a string
362                        line[..idx].trim()
363                    } else {
364                        line
365                    }
366                } else {
367                    line
368                }
369            })
370            .filter(|line| !line.is_empty())
371            .collect::<Vec<_>>()
372            .join("\n");
373
374        // Split by semicolon, but respect dollar-quoted strings ($$...$$)
375        let statements = Self::split_sql_statements(&cleaned_sql);
376
377        tracing::debug!(
378            "Executing {} statements for migration {}",
379            statements.len(),
380            migration.version
381        );
382
383        for (idx, statement) in statements.iter().enumerate() {
384            if !statement.trim().is_empty() {
385                tracing::debug!(
386                    "Executing statement {} of {}: {}...",
387                    idx + 1,
388                    statements.len(),
389                    &statement.chars().take(50).collect::<String>()
390                );
391                sqlx::query(statement)
392                    .execute(&mut *tx)
393                    .await
394                    .map_err(|e| {
395                        anyhow::anyhow!(
396                            "Failed to execute statement {} in migration {}: {}\nStatement: {}",
397                            idx + 1,
398                            migration.version,
399                            e,
400                            statement
401                        )
402                    })?;
403            }
404        }
405
406        // Record migration as applied
407        sqlx::query(&format!(
408            "INSERT INTO {}._duroxide_migrations (version, name) VALUES ($1, $2)",
409            self.schema_name
410        ))
411        .bind(migration.version)
412        .bind(&migration.name)
413        .execute(&mut *tx)
414        .await?;
415
416        // Commit transaction
417        tx.commit().await?;
418
419        tracing::info!(
420            "Applied migration {}: {}",
421            migration.version,
422            migration.name
423        );
424
425        Ok(())
426    }
427}