Skip to main content

duroxide_pg/
migrations.rs

1use anyhow::Result;
2use include_dir::{include_dir, Dir};
3use sqlx::PgPool;
4use std::sync::Arc;
5
6static MIGRATIONS: Dir = include_dir!("$CARGO_MANIFEST_DIR/migrations");
7
/// A single SQL migration script loaded from the embedded migrations directory.
#[derive(Debug)]
struct Migration {
    /// Integer parsed from the file name's prefix (text before the first `_`);
    /// stored as the primary key of the tracking table.
    version: i64,
    /// Full file name of the script, e.g. `0001_initial.sql`.
    name: String,
    /// Raw SQL text of the script.
    sql: String,
}
15
16/// Migration runner that handles schema-qualified migrations
17pub struct MigrationRunner {
18    pool: Arc<PgPool>,
19    schema_name: String,
20}
21
22impl MigrationRunner {
23    /// Create a new migration runner
24    pub fn new(pool: Arc<PgPool>, schema_name: String) -> Self {
25        Self { pool, schema_name }
26    }
27
28    /// Run all pending migrations
29    pub async fn migrate(&self) -> Result<()> {
30        // Ensure schema exists
31        if self.schema_name != "public" {
32            sqlx::query(&format!("CREATE SCHEMA IF NOT EXISTS {}", self.schema_name))
33                .execute(&*self.pool)
34                .await?;
35        }
36
37        // Load migrations from filesystem
38        let migrations = self.load_migrations()?;
39
40        tracing::debug!(
41            "Loaded {} migrations for schema {}",
42            migrations.len(),
43            self.schema_name
44        );
45
46        // Ensure migration tracking table exists (in the schema)
47        self.ensure_migration_table().await?;
48
49        // Get applied migrations
50        let applied_versions = self.get_applied_versions().await?;
51
52        tracing::debug!("Applied migrations: {:?}", applied_versions);
53
54        // Check if key tables exist - if not, we need to re-run migrations even if marked as applied
55        // This handles the case where cleanup dropped tables but not the migration tracking table
56        let tables_exist = self.check_tables_exist().await.unwrap_or(false);
57
58        // Apply pending migrations (or re-apply if tables don't exist)
59        for migration in migrations {
60            let should_apply = if !applied_versions.contains(&migration.version) {
61                true // New migration
62            } else if !tables_exist {
63                // Migration was applied but tables don't exist - re-apply
64                tracing::warn!(
65                    "Migration {} is marked as applied but tables don't exist, re-applying",
66                    migration.version
67                );
68                // Remove the old migration record so we can re-apply
69                sqlx::query(&format!(
70                    "DELETE FROM {}._duroxide_migrations WHERE version = $1",
71                    self.schema_name
72                ))
73                .bind(migration.version)
74                .execute(&*self.pool)
75                .await?;
76                true
77            } else {
78                false // Already applied and tables exist
79            };
80
81            if should_apply {
82                tracing::debug!(
83                    "Applying migration {}: {}",
84                    migration.version,
85                    migration.name
86                );
87                self.apply_migration(&migration).await?;
88            } else {
89                tracing::debug!(
90                    "Skipping migration {}: {} (already applied)",
91                    migration.version,
92                    migration.name
93                );
94            }
95        }
96
97        Ok(())
98    }
99
100    /// Load migrations from the embedded migrations directory
101    fn load_migrations(&self) -> Result<Vec<Migration>> {
102        let mut migrations = Vec::new();
103
104        // Get all files from embedded directory
105        let mut files: Vec<_> = MIGRATIONS
106            .files()
107            .filter(|file| file.path().extension().and_then(|ext| ext.to_str()) == Some("sql"))
108            .collect();
109
110        // Sort by path to ensure consistent ordering
111        files.sort_by_key(|f| f.path());
112
113        for file in files {
114            let file_name = file
115                .path()
116                .file_name()
117                .and_then(|n| n.to_str())
118                .ok_or_else(|| anyhow::anyhow!("Invalid filename in migrations"))?;
119
120            let sql = file
121                .contents_utf8()
122                .ok_or_else(|| anyhow::anyhow!("Migration file is not valid UTF-8: {file_name}"))?
123                .to_string();
124
125            let version = self.parse_version(file_name)?;
126            let name = file_name.to_string();
127
128            migrations.push(Migration { version, name, sql });
129        }
130
131        Ok(migrations)
132    }
133
134    /// Parse version number from migration filename (e.g., "0001_initial.sql" -> 1)
135    fn parse_version(&self, filename: &str) -> Result<i64> {
136        let version_str = filename
137            .split('_')
138            .next()
139            .ok_or_else(|| anyhow::anyhow!("Invalid migration filename: {filename}"))?;
140
141        version_str
142            .parse::<i64>()
143            .map_err(|e| anyhow::anyhow!("Invalid migration version {version_str}: {e}"))
144    }
145
146    /// Ensure migration tracking table exists
147    async fn ensure_migration_table(&self) -> Result<()> {
148        // Create migration table in the target schema
149        sqlx::query(&format!(
150            r#"
151            CREATE TABLE IF NOT EXISTS {}._duroxide_migrations (
152                version BIGINT PRIMARY KEY,
153                name TEXT NOT NULL,
154                applied_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP
155            )
156            "#,
157            self.schema_name
158        ))
159        .execute(&*self.pool)
160        .await?;
161
162        Ok(())
163    }
164
165    /// Check if key tables exist
166    async fn check_tables_exist(&self) -> Result<bool> {
167        // Check if instances table exists (as a proxy for all tables)
168        let exists: bool = sqlx::query_scalar(
169            "SELECT EXISTS(SELECT 1 FROM information_schema.tables WHERE table_schema = $1 AND table_name = 'instances')",
170        )
171        .bind(&self.schema_name)
172        .fetch_one(&*self.pool)
173        .await?;
174
175        Ok(exists)
176    }
177
178    /// Get list of applied migration versions
179    async fn get_applied_versions(&self) -> Result<Vec<i64>> {
180        let versions: Vec<i64> = sqlx::query_scalar(&format!(
181            "SELECT version FROM {}._duroxide_migrations ORDER BY version",
182            self.schema_name
183        ))
184        .fetch_all(&*self.pool)
185        .await?;
186
187        Ok(versions)
188    }
189
190    /// Split SQL into statements, respecting dollar-quoted strings ($$...$$)
191    /// This handles stored procedures and other constructs that use dollar-quoting
192    fn split_sql_statements(sql: &str) -> Vec<String> {
193        let mut statements = Vec::new();
194        let mut current_statement = String::new();
195        let chars: Vec<char> = sql.chars().collect();
196        let mut i = 0;
197        let mut in_dollar_quote = false;
198        let mut dollar_tag: Option<String> = None;
199
200        while i < chars.len() {
201            let ch = chars[i];
202
203            if !in_dollar_quote {
204                // Check for start of dollar-quoted string
205                if ch == '$' {
206                    let mut tag = String::new();
207                    tag.push(ch);
208                    i += 1;
209
210                    // Collect the tag (e.g., $$, $tag$, $function$)
211                    while i < chars.len() {
212                        let next_ch = chars[i];
213                        if next_ch == '$' {
214                            tag.push(next_ch);
215                            dollar_tag = Some(tag.clone());
216                            in_dollar_quote = true;
217                            current_statement.push_str(&tag);
218                            i += 1;
219                            break;
220                        } else if next_ch.is_alphanumeric() || next_ch == '_' {
221                            tag.push(next_ch);
222                            i += 1;
223                        } else {
224                            // Not a dollar quote, just a $ character
225                            current_statement.push(ch);
226                            break;
227                        }
228                    }
229                } else if ch == ';' {
230                    // End of statement (only if not in dollar quote)
231                    current_statement.push(ch);
232                    let trimmed = current_statement.trim().to_string();
233                    if !trimmed.is_empty() {
234                        statements.push(trimmed);
235                    }
236                    current_statement.clear();
237                    i += 1;
238                } else {
239                    current_statement.push(ch);
240                    i += 1;
241                }
242            } else {
243                // Inside dollar-quoted string
244                current_statement.push(ch);
245
246                // Check for end of dollar-quoted string
247                if ch == '$' {
248                    let tag = dollar_tag.as_ref().unwrap();
249                    let mut matches = true;
250
251                    // Check if the following characters match the closing tag
252                    for (j, tag_char) in tag.chars().enumerate() {
253                        if j == 0 {
254                            continue; // Skip first $ (we already matched it)
255                        }
256                        if i + j >= chars.len() || chars[i + j] != tag_char {
257                            matches = false;
258                            break;
259                        }
260                    }
261
262                    if matches {
263                        // Found closing tag - consume remaining tag characters
264                        for _ in 0..(tag.len() - 1) {
265                            if i + 1 < chars.len() {
266                                current_statement.push(chars[i + 1]);
267                                i += 1;
268                            }
269                        }
270                        in_dollar_quote = false;
271                        dollar_tag = None;
272                    }
273                }
274                i += 1;
275            }
276        }
277
278        // Add remaining statement if any
279        let trimmed = current_statement.trim().to_string();
280        if !trimmed.is_empty() {
281            statements.push(trimmed);
282        }
283
284        statements
285    }
286
287    /// Apply a single migration
288    async fn apply_migration(&self, migration: &Migration) -> Result<()> {
289        // Start transaction
290        let mut tx = self.pool.begin().await?;
291
292        // Set search_path for this transaction
293        sqlx::query(&format!("SET LOCAL search_path TO {}", self.schema_name))
294            .execute(&mut *tx)
295            .await?;
296
297        // Remove comment lines and split SQL into individual statements
298        let sql = migration.sql.trim();
299        let cleaned_sql: String = sql
300            .lines()
301            .map(|line| {
302                // Remove full-line comments
303                if let Some(idx) = line.find("--") {
304                    // Check if -- is inside a string (simple check)
305                    let before = &line[..idx];
306                    if before.matches('\'').count() % 2 == 0 {
307                        // Even number of quotes means -- is not in a string
308                        line[..idx].trim()
309                    } else {
310                        line
311                    }
312                } else {
313                    line
314                }
315            })
316            .filter(|line| !line.is_empty())
317            .collect::<Vec<_>>()
318            .join("\n");
319
320        // Split by semicolon, but respect dollar-quoted strings ($$...$$)
321        let statements = Self::split_sql_statements(&cleaned_sql);
322
323        tracing::debug!(
324            "Executing {} statements for migration {}",
325            statements.len(),
326            migration.version
327        );
328
329        for (idx, statement) in statements.iter().enumerate() {
330            if !statement.trim().is_empty() {
331                tracing::debug!(
332                    "Executing statement {} of {}: {}...",
333                    idx + 1,
334                    statements.len(),
335                    &statement.chars().take(50).collect::<String>()
336                );
337                sqlx::query(statement)
338                    .execute(&mut *tx)
339                    .await
340                    .map_err(|e| {
341                        anyhow::anyhow!(
342                            "Failed to execute statement {} in migration {}: {}\nStatement: {}",
343                            idx + 1,
344                            migration.version,
345                            e,
346                            statement
347                        )
348                    })?;
349            }
350        }
351
352        // Record migration as applied
353        sqlx::query(&format!(
354            "INSERT INTO {}._duroxide_migrations (version, name) VALUES ($1, $2)",
355            self.schema_name
356        ))
357        .bind(migration.version)
358        .bind(&migration.name)
359        .execute(&mut *tx)
360        .await?;
361
362        // Commit transaction
363        tx.commit().await?;
364
365        tracing::info!(
366            "Applied migration {}: {}",
367            migration.version,
368            migration.name
369        );
370
371        Ok(())
372    }
373}