Skip to main content

duroxide_pg_opt/
migrations.rs

1use anyhow::Result;
2use include_dir::{include_dir, Dir};
3use sqlx::PgPool;
4use std::sync::Arc;
5
6// Embedded SQL migrations. Touch this module when adding a new migration so
7// incremental builds re-embed the directory contents. Latest: 0007.
8static MIGRATIONS: Dir = include_dir!("$CARGO_MANIFEST_DIR/migrations");
9
/// One embedded migration file together with the metadata derived from it.
#[derive(Debug)]
struct Migration {
    /// Numeric version parsed from the filename prefix.
    version: i64,
    /// Filename of the migration, as recorded in the tracking table.
    name: String,
    /// Full SQL text of the migration file.
    sql: String,
}
17
18/// Migration runner that handles schema-qualified migrations
19pub struct MigrationRunner {
20    pool: Arc<PgPool>,
21    schema_name: String,
22}
23
24impl MigrationRunner {
25    /// Create a new migration runner
26    pub fn new(pool: Arc<PgPool>, schema_name: String) -> Self {
27        Self { pool, schema_name }
28    }
29
30    /// Run all pending migrations
31    pub async fn migrate(&self) -> Result<()> {
32        // Ensure schema exists
33        if self.schema_name != "public" {
34            sqlx::query(&format!("CREATE SCHEMA IF NOT EXISTS {}", self.schema_name))
35                .execute(&*self.pool)
36                .await?;
37        }
38
39        // Load migrations from filesystem
40        let migrations = self.load_migrations()?;
41
42        tracing::debug!(
43            "Loaded {} migrations for schema {}",
44            migrations.len(),
45            self.schema_name
46        );
47
48        // Ensure migration tracking table exists (in the schema)
49        self.ensure_migration_table().await?;
50
51        // Get applied migrations
52        let applied_versions = self.get_applied_versions().await?;
53
54        tracing::debug!("Applied migrations: {:?}", applied_versions);
55
56        // Check if key tables exist - if not, we need to re-run migrations even if marked as applied
57        // This handles the case where cleanup dropped tables but not the migration tracking table
58        let tables_exist = self.check_tables_exist().await.unwrap_or(false);
59
60        // Apply pending migrations (or re-apply if tables don't exist)
61        for migration in migrations {
62            let should_apply = if !applied_versions.contains(&migration.version) {
63                true // New migration
64            } else if !tables_exist {
65                // Migration was applied but tables don't exist - re-apply
66                tracing::warn!(
67                    "Migration {} is marked as applied but tables don't exist, re-applying",
68                    migration.version
69                );
70                // Remove the old migration record so we can re-apply
71                sqlx::query(&format!(
72                    "DELETE FROM {}._duroxide_migrations WHERE version = $1",
73                    self.schema_name
74                ))
75                .bind(migration.version)
76                .execute(&*self.pool)
77                .await?;
78                true
79            } else {
80                false // Already applied and tables exist
81            };
82
83            if should_apply {
84                tracing::debug!(
85                    "Applying migration {}: {}",
86                    migration.version,
87                    migration.name
88                );
89                self.apply_migration(&migration).await?;
90            } else {
91                tracing::debug!(
92                    "Skipping migration {}: {} (already applied)",
93                    migration.version,
94                    migration.name
95                );
96            }
97        }
98
99        Ok(())
100    }
101
102    /// Load migrations from the embedded migrations directory
103    fn load_migrations(&self) -> Result<Vec<Migration>> {
104        let mut migrations = Vec::new();
105
106        // Get all files from embedded directory
107        let mut files: Vec<_> = MIGRATIONS
108            .files()
109            .filter(|file| file.path().extension().and_then(|ext| ext.to_str()) == Some("sql"))
110            .collect();
111
112        // Sort by path to ensure consistent ordering
113        files.sort_by_key(|f| f.path());
114
115        for file in files {
116            let file_name = file
117                .path()
118                .file_name()
119                .and_then(|n| n.to_str())
120                .ok_or_else(|| anyhow::anyhow!("Invalid filename in migrations"))?;
121
122            let sql = file
123                .contents_utf8()
124                .ok_or_else(|| anyhow::anyhow!("Migration file is not valid UTF-8: {file_name}"))?
125                .to_string();
126
127            let version = self.parse_version(file_name)?;
128            let name = file_name.to_string();
129
130            migrations.push(Migration { version, name, sql });
131        }
132
133        Ok(migrations)
134    }
135
136    /// Parse version number from migration filename (e.g., "0001_initial.sql" -> 1)
137    fn parse_version(&self, filename: &str) -> Result<i64> {
138        let version_str = filename
139            .split('_')
140            .next()
141            .ok_or_else(|| anyhow::anyhow!("Invalid migration filename: {filename}"))?;
142
143        version_str
144            .parse::<i64>()
145            .map_err(|e| anyhow::anyhow!("Invalid migration version {version_str}: {e}"))
146    }
147
148    /// Ensure migration tracking table exists
149    async fn ensure_migration_table(&self) -> Result<()> {
150        // Create migration table in the target schema
151        sqlx::query(&format!(
152            r#"
153            CREATE TABLE IF NOT EXISTS {}._duroxide_migrations (
154                version BIGINT PRIMARY KEY,
155                name TEXT NOT NULL,
156                applied_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP
157            )
158            "#,
159            self.schema_name
160        ))
161        .execute(&*self.pool)
162        .await?;
163
164        Ok(())
165    }
166
167    /// Check if key tables exist
168    async fn check_tables_exist(&self) -> Result<bool> {
169        // Check if instances table exists (as a proxy for all tables)
170        let exists: bool = sqlx::query_scalar(
171            "SELECT EXISTS(SELECT 1 FROM information_schema.tables WHERE table_schema = $1 AND table_name = 'instances')",
172        )
173        .bind(&self.schema_name)
174        .fetch_one(&*self.pool)
175        .await?;
176
177        Ok(exists)
178    }
179
180    /// Get list of applied migration versions
181    async fn get_applied_versions(&self) -> Result<Vec<i64>> {
182        let versions: Vec<i64> = sqlx::query_scalar(&format!(
183            "SELECT version FROM {}._duroxide_migrations ORDER BY version",
184            self.schema_name
185        ))
186        .fetch_all(&*self.pool)
187        .await?;
188
189        Ok(versions)
190    }
191
192    /// Split SQL into statements, respecting dollar-quoted strings ($$...$$)
193    /// This handles stored procedures and other constructs that use dollar-quoting
194    fn split_sql_statements(sql: &str) -> Vec<String> {
195        let mut statements = Vec::new();
196        let mut current_statement = String::new();
197        let chars: Vec<char> = sql.chars().collect();
198        let mut i = 0;
199        let mut in_dollar_quote = false;
200        let mut dollar_tag: Option<String> = None;
201
202        while i < chars.len() {
203            let ch = chars[i];
204
205            if !in_dollar_quote {
206                // Check for start of dollar-quoted string
207                if ch == '$' {
208                    let mut tag = String::new();
209                    tag.push(ch);
210                    i += 1;
211
212                    // Collect the tag (e.g., $$, $tag$, $function$)
213                    while i < chars.len() {
214                        let next_ch = chars[i];
215                        if next_ch == '$' {
216                            tag.push(next_ch);
217                            dollar_tag = Some(tag.clone());
218                            in_dollar_quote = true;
219                            current_statement.push_str(&tag);
220                            i += 1;
221                            break;
222                        } else if next_ch.is_alphanumeric() || next_ch == '_' {
223                            tag.push(next_ch);
224                            i += 1;
225                        } else {
226                            // Not a dollar quote, just a $ character
227                            current_statement.push(ch);
228                            break;
229                        }
230                    }
231                } else if ch == ';' {
232                    // End of statement (only if not in dollar quote)
233                    current_statement.push(ch);
234                    let trimmed = current_statement.trim().to_string();
235                    if !trimmed.is_empty() {
236                        statements.push(trimmed);
237                    }
238                    current_statement.clear();
239                    i += 1;
240                } else {
241                    current_statement.push(ch);
242                    i += 1;
243                }
244            } else {
245                // Inside dollar-quoted string
246                current_statement.push(ch);
247
248                // Check for end of dollar-quoted string
249                if ch == '$' {
250                    let tag = dollar_tag.as_ref().unwrap();
251                    let mut matches = true;
252
253                    // Check if the following characters match the closing tag
254                    for (j, tag_char) in tag.chars().enumerate() {
255                        if j == 0 {
256                            continue; // Skip first $ (we already matched it)
257                        }
258                        if i + j >= chars.len() || chars[i + j] != tag_char {
259                            matches = false;
260                            break;
261                        }
262                    }
263
264                    if matches {
265                        // Found closing tag - consume remaining tag characters
266                        for _ in 0..(tag.len() - 1) {
267                            if i + 1 < chars.len() {
268                                current_statement.push(chars[i + 1]);
269                                i += 1;
270                            }
271                        }
272                        in_dollar_quote = false;
273                        dollar_tag = None;
274                    }
275                }
276                i += 1;
277            }
278        }
279
280        // Add remaining statement if any
281        let trimmed = current_statement.trim().to_string();
282        if !trimmed.is_empty() {
283            statements.push(trimmed);
284        }
285
286        statements
287    }
288
289    /// Apply a single migration
290    async fn apply_migration(&self, migration: &Migration) -> Result<()> {
291        // Start transaction
292        let mut tx = self.pool.begin().await?;
293
294        // Set search_path for this transaction
295        sqlx::query(&format!("SET LOCAL search_path TO {}", self.schema_name))
296            .execute(&mut *tx)
297            .await?;
298
299        // Remove comment lines and split SQL into individual statements
300        let sql = migration.sql.trim();
301        let cleaned_sql: String = sql
302            .lines()
303            .map(|line| {
304                // Remove full-line comments
305                if let Some(idx) = line.find("--") {
306                    // Check if -- is inside a string (simple check)
307                    let before = &line[..idx];
308                    if before.matches('\'').count() % 2 == 0 {
309                        // Even number of quotes means -- is not in a string
310                        line[..idx].trim()
311                    } else {
312                        line
313                    }
314                } else {
315                    line
316                }
317            })
318            .filter(|line| !line.is_empty())
319            .collect::<Vec<_>>()
320            .join("\n");
321
322        // Split by semicolon, but respect dollar-quoted strings ($$...$$)
323        let statements = Self::split_sql_statements(&cleaned_sql);
324
325        tracing::debug!(
326            "Executing {} statements for migration {}",
327            statements.len(),
328            migration.version
329        );
330
331        for (idx, statement) in statements.iter().enumerate() {
332            if !statement.trim().is_empty() {
333                tracing::debug!(
334                    "Executing statement {} of {}: {}...",
335                    idx + 1,
336                    statements.len(),
337                    &statement.chars().take(50).collect::<String>()
338                );
339                sqlx::query(statement)
340                    .execute(&mut *tx)
341                    .await
342                    .map_err(|e| {
343                        anyhow::anyhow!(
344                            "Failed to execute statement {} in migration {}: {}\nStatement: {}",
345                            idx + 1,
346                            migration.version,
347                            e,
348                            statement
349                        )
350                    })?;
351            }
352        }
353
354        // Record migration as applied
355        sqlx::query(&format!(
356            "INSERT INTO {}._duroxide_migrations (version, name) VALUES ($1, $2)",
357            self.schema_name
358        ))
359        .bind(migration.version)
360        .bind(&migration.name)
361        .execute(&mut *tx)
362        .await?;
363
364        // Commit transaction
365        tx.commit().await?;
366
367        tracing::info!(
368            "Applied migration {}: {}",
369            migration.version,
370            migration.name
371        );
372
373        Ok(())
374    }
375}