Skip to main content

duroxide_pg/
migrations.rs

1use anyhow::Result;
2use include_dir::{include_dir, Dir};
3use sqlx::Connection;
4use sqlx::PgPool;
5use std::sync::Arc;
6
7static MIGRATIONS: Dir = include_dir!("$CARGO_MANIFEST_DIR/migrations");
8
/// One migration script loaded from the embedded migrations directory.
#[derive(Debug)]
struct Migration {
    /// Numeric prefix of the file name (e.g. `0001_initial.sql` -> 1); stored as the
    /// primary key of the `_duroxide_migrations` tracking table once applied.
    version: i64,
    /// Full file name, recorded next to the version when the migration is applied.
    name: String,
    /// The file's contents, decoded as UTF-8.
    sql: String,
}
16
/// Applies the crate's embedded SQL migrations to one PostgreSQL schema and records
/// the applied versions in that schema's `_duroxide_migrations` table.
pub struct MigrationRunner {
    /// Pool from which each migration run acquires its connection.
    pool: Arc<PgPool>,
    /// Target schema; interpolated into the generated SQL without quoting.
    schema_name: String,
}
22
23impl MigrationRunner {
24    /// Create a new migration runner
25    pub fn new(pool: Arc<PgPool>, schema_name: String) -> Self {
26        Self { pool, schema_name }
27    }
28
29    fn advisory_lock_key(&self) -> i64 {
30        // Stable 64-bit FNV-1a hash over (namespace + schema name).
31        // This avoids using Rust's DefaultHasher (randomized per-process).
32        const OFFSET: u64 = 0xcbf29ce484222325;
33        const PRIME: u64 = 0x100000001b3;
34
35        let mut hash = OFFSET;
36        for b in b"duroxide_pg:migrations:" {
37            hash ^= *b as u64;
38            hash = hash.wrapping_mul(PRIME);
39        }
40        for b in self.schema_name.as_bytes() {
41            hash ^= *b as u64;
42            hash = hash.wrapping_mul(PRIME);
43        }
44
45        hash as i64
46    }
47
48    async fn lock_for_migrations(&self, conn: &mut sqlx::postgres::PgConnection) -> Result<()> {
49        let key = self.advisory_lock_key();
50        // Session lock (not xact lock) so it spans multiple transactions.
51        // We explicitly unlock at the end.
52        sqlx::query("SELECT pg_advisory_lock($1)")
53            .bind(key)
54            .execute(&mut *conn)
55            .await?;
56        Ok(())
57    }
58
59    async fn unlock_for_migrations(&self, conn: &mut sqlx::postgres::PgConnection) {
60        let key = self.advisory_lock_key();
61        // Best-effort unlock.
62        let _ = sqlx::query("SELECT pg_advisory_unlock($1)")
63            .bind(key)
64            .execute(&mut *conn)
65            .await;
66    }
67
68    /// Run all pending migrations
69    pub async fn migrate(&self) -> Result<()> {
70        let mut conn = self.pool.acquire().await?;
71        let conn = &mut *conn;
72        self.lock_for_migrations(conn).await?;
73
74        let result = self.migrate_inner(conn).await;
75        self.unlock_for_migrations(conn).await;
76
77        result
78    }
79
80    async fn migrate_inner(&self, conn: &mut sqlx::postgres::PgConnection) -> Result<()> {
81        // Ensure schema exists
82        if self.schema_name != "public" {
83            sqlx::query(&format!("CREATE SCHEMA IF NOT EXISTS {}", self.schema_name))
84                .execute(&mut *conn)
85                .await?;
86        }
87
88        // Load migrations from filesystem
89        let migrations = self.load_migrations()?;
90
91        tracing::debug!(
92            "Loaded {} migrations for schema {}",
93            migrations.len(),
94            self.schema_name
95        );
96
97        // Ensure migration tracking table exists (in the schema)
98        self.ensure_migration_table(conn).await?;
99
100        // Get applied migrations
101        let applied_versions = self.get_applied_versions(conn).await?;
102
103        tracing::debug!("Applied migrations: {:?}", applied_versions);
104
105        // Check if key tables exist - if not, we need to re-run migrations even if marked as applied
106        // This handles the case where cleanup dropped tables but not the migration tracking table
107        let tables_exist = self.check_tables_exist(conn).await.unwrap_or(false);
108
109        // Apply pending migrations (or re-apply if tables don't exist)
110        for migration in migrations {
111            let should_apply = if !applied_versions.contains(&migration.version) {
112                true // New migration
113            } else if !tables_exist {
114                // Migration was applied but tables don't exist - re-apply
115                tracing::warn!(
116                    "Migration {} is marked as applied but tables don't exist, re-applying",
117                    migration.version
118                );
119                // Remove the old migration record so we can re-apply
120                sqlx::query(&format!(
121                    "DELETE FROM {}._duroxide_migrations WHERE version = $1",
122                    self.schema_name
123                ))
124                .bind(migration.version)
125                .execute(&mut *conn)
126                .await?;
127                true
128            } else {
129                false // Already applied and tables exist
130            };
131
132            if should_apply {
133                tracing::debug!(
134                    "Applying migration {}: {}",
135                    migration.version,
136                    migration.name
137                );
138                self.apply_migration(conn, &migration).await?;
139            } else {
140                tracing::debug!(
141                    "Skipping migration {}: {} (already applied)",
142                    migration.version,
143                    migration.name
144                );
145            }
146        }
147
148        Ok(())
149    }
150
151    /// Load migrations from the embedded migrations directory
152    fn load_migrations(&self) -> Result<Vec<Migration>> {
153        let mut migrations = Vec::new();
154
155        // Get all files from embedded directory
156        let mut files: Vec<_> = MIGRATIONS
157            .files()
158            .filter(|file| file.path().extension().and_then(|ext| ext.to_str()) == Some("sql"))
159            .collect();
160
161        // Sort by path to ensure consistent ordering
162        files.sort_by_key(|f| f.path());
163
164        for file in files {
165            let file_name = file
166                .path()
167                .file_name()
168                .and_then(|n| n.to_str())
169                .ok_or_else(|| anyhow::anyhow!("Invalid filename in migrations"))?;
170
171            let sql = file
172                .contents_utf8()
173                .ok_or_else(|| anyhow::anyhow!("Migration file is not valid UTF-8: {file_name}"))?
174                .to_string();
175
176            let version = self.parse_version(file_name)?;
177            let name = file_name.to_string();
178
179            migrations.push(Migration { version, name, sql });
180        }
181
182        Ok(migrations)
183    }
184
185    /// Parse version number from migration filename (e.g., "0001_initial.sql" -> 1)
186    fn parse_version(&self, filename: &str) -> Result<i64> {
187        let version_str = filename
188            .split('_')
189            .next()
190            .ok_or_else(|| anyhow::anyhow!("Invalid migration filename: {filename}"))?;
191
192        version_str
193            .parse::<i64>()
194            .map_err(|e| anyhow::anyhow!("Invalid migration version {version_str}: {e}"))
195    }
196
197    /// Ensure migration tracking table exists
198    async fn ensure_migration_table(&self, conn: &mut sqlx::postgres::PgConnection) -> Result<()> {
199        // Create migration table in the target schema
200        sqlx::query(&format!(
201            r#"
202            CREATE TABLE IF NOT EXISTS {}._duroxide_migrations (
203                version BIGINT PRIMARY KEY,
204                name TEXT NOT NULL,
205                applied_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP
206            )
207            "#,
208            self.schema_name
209        ))
210        .execute(&mut *conn)
211        .await?;
212
213        Ok(())
214    }
215
216    /// Check if key tables exist
217    async fn check_tables_exist(&self, conn: &mut sqlx::postgres::PgConnection) -> Result<bool> {
218        // Check if instances table exists (as a proxy for all tables)
219        let exists: bool = sqlx::query_scalar(
220            "SELECT EXISTS(SELECT 1 FROM information_schema.tables WHERE table_schema = $1 AND table_name = 'instances')",
221        )
222        .bind(&self.schema_name)
223        .fetch_one(&mut *conn)
224        .await?;
225
226        Ok(exists)
227    }
228
229    /// Get list of applied migration versions
230    async fn get_applied_versions(
231        &self,
232        conn: &mut sqlx::postgres::PgConnection,
233    ) -> Result<Vec<i64>> {
234        let versions: Vec<i64> = sqlx::query_scalar(&format!(
235            "SELECT version FROM {}._duroxide_migrations ORDER BY version",
236            self.schema_name
237        ))
238        .fetch_all(&mut *conn)
239        .await?;
240
241        Ok(versions)
242    }
243
244    /// Split SQL into statements, respecting dollar-quoted strings ($$...$$)
245    /// This handles stored procedures and other constructs that use dollar-quoting
246    fn split_sql_statements(sql: &str) -> Vec<String> {
247        let mut statements = Vec::new();
248        let mut current_statement = String::new();
249        let chars: Vec<char> = sql.chars().collect();
250        let mut i = 0;
251        let mut in_dollar_quote = false;
252        let mut dollar_tag: Option<String> = None;
253
254        while i < chars.len() {
255            let ch = chars[i];
256
257            if !in_dollar_quote {
258                // Check for start of dollar-quoted string
259                if ch == '$' {
260                    let mut tag = String::new();
261                    tag.push(ch);
262                    i += 1;
263
264                    // Collect the tag (e.g., $$, $tag$, $function$)
265                    while i < chars.len() {
266                        let next_ch = chars[i];
267                        if next_ch == '$' {
268                            tag.push(next_ch);
269                            dollar_tag = Some(tag.clone());
270                            in_dollar_quote = true;
271                            current_statement.push_str(&tag);
272                            i += 1;
273                            break;
274                        } else if next_ch.is_alphanumeric() || next_ch == '_' {
275                            tag.push(next_ch);
276                            i += 1;
277                        } else {
278                            // Not a dollar quote, just a $ character
279                            current_statement.push(ch);
280                            break;
281                        }
282                    }
283                } else if ch == ';' {
284                    // End of statement (only if not in dollar quote)
285                    current_statement.push(ch);
286                    let trimmed = current_statement.trim().to_string();
287                    if !trimmed.is_empty() {
288                        statements.push(trimmed);
289                    }
290                    current_statement.clear();
291                    i += 1;
292                } else {
293                    current_statement.push(ch);
294                    i += 1;
295                }
296            } else {
297                // Inside dollar-quoted string
298                current_statement.push(ch);
299
300                // Check for end of dollar-quoted string
301                if ch == '$' {
302                    let tag = dollar_tag.as_ref().unwrap();
303                    let mut matches = true;
304
305                    // Check if the following characters match the closing tag
306                    for (j, tag_char) in tag.chars().enumerate() {
307                        if j == 0 {
308                            continue; // Skip first $ (we already matched it)
309                        }
310                        if i + j >= chars.len() || chars[i + j] != tag_char {
311                            matches = false;
312                            break;
313                        }
314                    }
315
316                    if matches {
317                        // Found closing tag - consume remaining tag characters
318                        for _ in 0..(tag.len() - 1) {
319                            if i + 1 < chars.len() {
320                                current_statement.push(chars[i + 1]);
321                                i += 1;
322                            }
323                        }
324                        in_dollar_quote = false;
325                        dollar_tag = None;
326                    }
327                }
328                i += 1;
329            }
330        }
331
332        // Add remaining statement if any
333        let trimmed = current_statement.trim().to_string();
334        if !trimmed.is_empty() {
335            statements.push(trimmed);
336        }
337
338        statements
339    }
340
341    /// Apply a single migration
342    async fn apply_migration(
343        &self,
344        conn: &mut sqlx::postgres::PgConnection,
345        migration: &Migration,
346    ) -> Result<()> {
347        // Start transaction
348        let mut tx = conn.begin().await?;
349
350        // Set search_path for this transaction
351        sqlx::query(&format!("SET LOCAL search_path TO {}", self.schema_name))
352            .execute(&mut *tx)
353            .await?;
354
355        // Remove comment lines and split SQL into individual statements
356        let sql = migration.sql.trim();
357        let cleaned_sql: String = sql
358            .lines()
359            .map(|line| {
360                // Remove full-line comments
361                if let Some(idx) = line.find("--") {
362                    // Check if -- is inside a string (simple check)
363                    let before = &line[..idx];
364                    if before.matches('\'').count() % 2 == 0 {
365                        // Even number of quotes means -- is not in a string
366                        line[..idx].trim()
367                    } else {
368                        line
369                    }
370                } else {
371                    line
372                }
373            })
374            .filter(|line| !line.is_empty())
375            .collect::<Vec<_>>()
376            .join("\n");
377
378        // Split by semicolon, but respect dollar-quoted strings ($$...$$)
379        let statements = Self::split_sql_statements(&cleaned_sql);
380
381        tracing::debug!(
382            "Executing {} statements for migration {}",
383            statements.len(),
384            migration.version
385        );
386
387        for (idx, statement) in statements.iter().enumerate() {
388            if !statement.trim().is_empty() {
389                tracing::debug!(
390                    "Executing statement {} of {}: {}...",
391                    idx + 1,
392                    statements.len(),
393                    &statement.chars().take(50).collect::<String>()
394                );
395                sqlx::query(statement)
396                    .execute(&mut *tx)
397                    .await
398                    .map_err(|e| {
399                        anyhow::anyhow!(
400                            "Failed to execute statement {} in migration {}: {}\nStatement: {}",
401                            idx + 1,
402                            migration.version,
403                            e,
404                            statement
405                        )
406                    })?;
407            }
408        }
409
410        // Record migration as applied
411        sqlx::query(&format!(
412            "INSERT INTO {}._duroxide_migrations (version, name) VALUES ($1, $2)",
413            self.schema_name
414        ))
415        .bind(migration.version)
416        .bind(&migration.name)
417        .execute(&mut *tx)
418        .await?;
419
420        // Commit transaction
421        tx.commit().await?;
422
423        tracing::info!(
424            "Applied migration {}: {}",
425            migration.version,
426            migration.name
427        );
428
429        Ok(())
430    }
431}