duroxide_pg/
migrations.rs

1use anyhow::Result;
2use sqlx::PgPool;
3use std::sync::Arc;
4use include_dir::{include_dir, Dir};
5
/// Migration files embedded at compile time from the crate's `migrations/`
/// directory (resolved relative to `$CARGO_MANIFEST_DIR`).
static MIGRATIONS: Dir = include_dir!("$CARGO_MANIFEST_DIR/migrations");
7
/// Migration metadata: one SQL file loaded from the embedded migrations directory.
#[derive(Debug)]
struct Migration {
    /// Numeric version parsed from the filename prefix before the first `_`.
    version: i64,
    /// Full filename of the migration (e.g. `0001_initial.sql`); stored in the tracking table.
    name: String,
    /// UTF-8 contents of the migration file.
    sql: String,
}
15
/// Migration runner that handles schema-qualified migrations
pub struct MigrationRunner {
    /// Shared connection pool used for every query issued by the runner.
    pool: Arc<PgPool>,
    /// Target schema. It is interpolated into SQL text as-is (unquoted), so it
    /// must be a trusted, valid PostgreSQL identifier.
    schema_name: String,
}
21
22impl MigrationRunner {
23    /// Create a new migration runner
24    pub fn new(pool: Arc<PgPool>, schema_name: String) -> Self {
25        Self { pool, schema_name }
26    }
27
28    /// Run all pending migrations
29    pub async fn migrate(&self) -> Result<()> {
30        // Ensure schema exists
31        if self.schema_name != "public" {
32            sqlx::query(&format!("CREATE SCHEMA IF NOT EXISTS {}", self.schema_name))
33                .execute(&*self.pool)
34                .await?;
35        }
36
37        // Load migrations from filesystem
38        let migrations = self.load_migrations()?;
39
40        tracing::debug!(
41            "Loaded {} migrations for schema {}",
42            migrations.len(),
43            self.schema_name
44        );
45
46        // Ensure migration tracking table exists (in the schema)
47        self.ensure_migration_table().await?;
48
49        // Get applied migrations
50        let applied_versions = self.get_applied_versions().await?;
51
52        tracing::debug!("Applied migrations: {:?}", applied_versions);
53
54        // Check if key tables exist - if not, we need to re-run migrations even if marked as applied
55        // This handles the case where cleanup dropped tables but not the migration tracking table
56        let tables_exist = self.check_tables_exist().await.unwrap_or(false);
57
58        // Apply pending migrations (or re-apply if tables don't exist)
59        for migration in migrations {
60            let should_apply = if !applied_versions.contains(&migration.version) {
61                true // New migration
62            } else if !tables_exist {
63                // Migration was applied but tables don't exist - re-apply
64                tracing::warn!(
65                    "Migration {} is marked as applied but tables don't exist, re-applying",
66                    migration.version
67                );
68                // Remove the old migration record so we can re-apply
69                sqlx::query(&format!(
70                    "DELETE FROM {}._duroxide_migrations WHERE version = $1",
71                    self.schema_name
72                ))
73                .bind(migration.version)
74                .execute(&*self.pool)
75                .await?;
76                true
77            } else {
78                false // Already applied and tables exist
79            };
80
81            if should_apply {
82                tracing::debug!(
83                    "Applying migration {}: {}",
84                    migration.version,
85                    migration.name
86                );
87                self.apply_migration(&migration).await?;
88            } else {
89                tracing::debug!(
90                    "Skipping migration {}: {} (already applied)",
91                    migration.version,
92                    migration.name
93                );
94            }
95        }
96
97        Ok(())
98    }
99
100    /// Load migrations from the embedded migrations directory
101    fn load_migrations(&self) -> Result<Vec<Migration>> {
102        let mut migrations = Vec::new();
103
104        // Get all files from embedded directory
105        let mut files: Vec<_> = MIGRATIONS
106            .files()
107            .filter(|file| {
108                file.path()
109                    .extension()
110                    .and_then(|ext| ext.to_str())
111                    == Some("sql")
112            })
113            .collect();
114
115        // Sort by path to ensure consistent ordering
116        files.sort_by_key(|f| f.path());
117
118        for file in files {
119            let file_name = file
120                .path()
121                .file_name()
122                .and_then(|n| n.to_str())
123                .ok_or_else(|| anyhow::anyhow!("Invalid filename in migrations"))?;
124
125            let sql = file
126                .contents_utf8()
127                .ok_or_else(|| anyhow::anyhow!("Migration file is not valid UTF-8: {}", file_name))?
128                .to_string();
129
130            let version = self.parse_version(file_name)?;
131            let name = file_name.to_string();
132
133            migrations.push(Migration { version, name, sql });
134        }
135
136        Ok(migrations)
137    }
138
139    /// Parse version number from migration filename (e.g., "0001_initial.sql" -> 1)
140    fn parse_version(&self, filename: &str) -> Result<i64> {
141        let version_str = filename
142            .split('_')
143            .next()
144            .ok_or_else(|| anyhow::anyhow!("Invalid migration filename: {}", filename))?;
145
146        version_str
147            .parse::<i64>()
148            .map_err(|e| anyhow::anyhow!("Invalid migration version {}: {}", version_str, e))
149    }
150
151    /// Ensure migration tracking table exists
152    async fn ensure_migration_table(&self) -> Result<()> {
153        // Create migration table in the target schema
154        sqlx::query(&format!(
155            r#"
156            CREATE TABLE IF NOT EXISTS {}._duroxide_migrations (
157                version BIGINT PRIMARY KEY,
158                name TEXT NOT NULL,
159                applied_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP
160            )
161            "#,
162            self.schema_name
163        ))
164        .execute(&*self.pool)
165        .await?;
166
167        Ok(())
168    }
169
170    /// Check if key tables exist
171    async fn check_tables_exist(&self) -> Result<bool> {
172        // Check if instances table exists (as a proxy for all tables)
173        let exists: bool = sqlx::query_scalar(&format!(
174            "SELECT EXISTS(SELECT 1 FROM information_schema.tables WHERE table_schema = $1 AND table_name = 'instances')",
175        ))
176        .bind(&self.schema_name)
177        .fetch_one(&*self.pool)
178        .await?;
179
180        Ok(exists)
181    }
182
183    /// Get list of applied migration versions
184    async fn get_applied_versions(&self) -> Result<Vec<i64>> {
185        let versions: Vec<i64> = sqlx::query_scalar(&format!(
186            "SELECT version FROM {}._duroxide_migrations ORDER BY version",
187            self.schema_name
188        ))
189        .fetch_all(&*self.pool)
190        .await?;
191
192        Ok(versions)
193    }
194
195    /// Split SQL into statements, respecting dollar-quoted strings ($$...$$)
196    /// This handles stored procedures and other constructs that use dollar-quoting
197    fn split_sql_statements(sql: &str) -> Vec<String> {
198        let mut statements = Vec::new();
199        let mut current_statement = String::new();
200        let chars: Vec<char> = sql.chars().collect();
201        let mut i = 0;
202        let mut in_dollar_quote = false;
203        let mut dollar_tag: Option<String> = None;
204
205        while i < chars.len() {
206            let ch = chars[i];
207
208            if !in_dollar_quote {
209                // Check for start of dollar-quoted string
210                if ch == '$' {
211                    let mut tag = String::new();
212                    tag.push(ch);
213                    i += 1;
214
215                    // Collect the tag (e.g., $$, $tag$, $function$)
216                    while i < chars.len() {
217                        let next_ch = chars[i];
218                        if next_ch == '$' {
219                            tag.push(next_ch);
220                            dollar_tag = Some(tag.clone());
221                            in_dollar_quote = true;
222                            current_statement.push_str(&tag);
223                            i += 1;
224                            break;
225                        } else if next_ch.is_alphanumeric() || next_ch == '_' {
226                            tag.push(next_ch);
227                            i += 1;
228                        } else {
229                            // Not a dollar quote, just a $ character
230                            current_statement.push(ch);
231                            break;
232                        }
233                    }
234                } else if ch == ';' {
235                    // End of statement (only if not in dollar quote)
236                    current_statement.push(ch);
237                    let trimmed = current_statement.trim().to_string();
238                    if !trimmed.is_empty() {
239                        statements.push(trimmed);
240                    }
241                    current_statement.clear();
242                    i += 1;
243                } else {
244                    current_statement.push(ch);
245                    i += 1;
246                }
247            } else {
248                // Inside dollar-quoted string
249                current_statement.push(ch);
250
251                // Check for end of dollar-quoted string
252                if ch == '$' {
253                    let tag = dollar_tag.as_ref().unwrap();
254                    let mut matches = true;
255
256                    // Check if the following characters match the closing tag
257                    for (j, tag_char) in tag.chars().enumerate() {
258                        if j == 0 {
259                            continue; // Skip first $ (we already matched it)
260                        }
261                        if i + j >= chars.len() || chars[i + j] != tag_char {
262                            matches = false;
263                            break;
264                        }
265                    }
266
267                    if matches {
268                        // Found closing tag - consume remaining tag characters
269                        for _ in 0..(tag.len() - 1) {
270                            if i + 1 < chars.len() {
271                                current_statement.push(chars[i + 1]);
272                                i += 1;
273                            }
274                        }
275                        in_dollar_quote = false;
276                        dollar_tag = None;
277                    }
278                }
279                i += 1;
280            }
281        }
282
283        // Add remaining statement if any
284        let trimmed = current_statement.trim().to_string();
285        if !trimmed.is_empty() {
286            statements.push(trimmed);
287        }
288
289        statements
290    }
291
292    /// Apply a single migration
293    async fn apply_migration(&self, migration: &Migration) -> Result<()> {
294        // Start transaction
295        let mut tx = self.pool.begin().await?;
296
297        // Set search_path for this transaction
298        sqlx::query(&format!("SET LOCAL search_path TO {}", self.schema_name))
299            .execute(&mut *tx)
300            .await?;
301
302        // Remove comment lines and split SQL into individual statements
303        let sql = migration.sql.trim();
304        let cleaned_sql: String = sql
305            .lines()
306            .map(|line| {
307                // Remove full-line comments
308                if let Some(idx) = line.find("--") {
309                    // Check if -- is inside a string (simple check)
310                    let before = &line[..idx];
311                    if before.matches('\'').count() % 2 == 0 {
312                        // Even number of quotes means -- is not in a string
313                        line[..idx].trim()
314                    } else {
315                        line
316                    }
317                } else {
318                    line
319                }
320            })
321            .filter(|line| !line.is_empty())
322            .collect::<Vec<_>>()
323            .join("\n");
324
325        // Split by semicolon, but respect dollar-quoted strings ($$...$$)
326        let statements = Self::split_sql_statements(&cleaned_sql);
327
328        tracing::debug!(
329            "Executing {} statements for migration {}",
330            statements.len(),
331            migration.version
332        );
333
334        for (idx, statement) in statements.iter().enumerate() {
335            if !statement.trim().is_empty() {
336                tracing::debug!(
337                    "Executing statement {} of {}: {}...",
338                    idx + 1,
339                    statements.len(),
340                    &statement.chars().take(50).collect::<String>()
341                );
342                sqlx::query(statement)
343                    .execute(&mut *tx)
344                    .await
345                    .map_err(|e| {
346                        anyhow::anyhow!(
347                            "Failed to execute statement {} in migration {}: {}\nStatement: {}",
348                            idx + 1,
349                            migration.version,
350                            e,
351                            statement
352                        )
353                    })?;
354            }
355        }
356
357        // Record migration as applied
358        sqlx::query(&format!(
359            "INSERT INTO {}._duroxide_migrations (version, name) VALUES ($1, $2)",
360            self.schema_name
361        ))
362        .bind(migration.version)
363        .bind(&migration.name)
364        .execute(&mut *tx)
365        .await?;
366
367        // Commit transaction
368        tx.commit().await?;
369
370        tracing::info!(
371            "Applied migration {}: {}",
372            migration.version,
373            migration.name
374        );
375
376        Ok(())
377    }
378}