duroxide_pg/
migrations.rs1use anyhow::Result;
2use include_dir::{include_dir, Dir};
3use sqlx::Connection;
4use sqlx::PgPool;
5use std::sync::Arc;
6
7static MIGRATIONS: Dir = include_dir!("$CARGO_MANIFEST_DIR/migrations");
8
/// One embedded migration script.
#[derive(Debug)]
struct Migration {
    /// Numeric version parsed from the file-name prefix before the first `_`.
    version: i64,
    /// The script's file name; recorded in `_duroxide_migrations.name`.
    name: String,
    /// Full UTF-8 contents of the `.sql` file.
    sql: String,
}
16
17pub struct MigrationRunner {
19 pool: Arc<PgPool>,
20 schema_name: String,
21}
22
23impl MigrationRunner {
24 pub fn new(pool: Arc<PgPool>, schema_name: String) -> Self {
26 Self { pool, schema_name }
27 }
28
29 fn advisory_lock_key(&self) -> i64 {
30 const OFFSET: u64 = 0xcbf29ce484222325;
33 const PRIME: u64 = 0x100000001b3;
34
35 let mut hash = OFFSET;
36 for b in b"duroxide_pg:migrations:" {
37 hash ^= *b as u64;
38 hash = hash.wrapping_mul(PRIME);
39 }
40 for b in self.schema_name.as_bytes() {
41 hash ^= *b as u64;
42 hash = hash.wrapping_mul(PRIME);
43 }
44
45 hash as i64
46 }
47
48 async fn lock_for_migrations(&self, conn: &mut sqlx::postgres::PgConnection) -> Result<()> {
49 let key = self.advisory_lock_key();
50 sqlx::query("SELECT pg_advisory_lock($1)")
53 .bind(key)
54 .execute(&mut *conn)
55 .await?;
56 Ok(())
57 }
58
59 async fn unlock_for_migrations(&self, conn: &mut sqlx::postgres::PgConnection) {
60 let key = self.advisory_lock_key();
61 let _ = sqlx::query("SELECT pg_advisory_unlock($1)")
63 .bind(key)
64 .execute(&mut *conn)
65 .await;
66 }
67
68 pub async fn migrate(&self) -> Result<()> {
70 let mut conn = self.pool.acquire().await?;
71 let conn = &mut *conn;
72 self.lock_for_migrations(conn).await?;
73
74 let result = self.migrate_inner(conn).await;
75 self.unlock_for_migrations(conn).await;
76
77 result
78 }
79
80 async fn migrate_inner(
81 &self,
82 conn: &mut sqlx::postgres::PgConnection,
83 ) -> Result<()> {
84 if self.schema_name != "public" {
86 sqlx::query(&format!("CREATE SCHEMA IF NOT EXISTS {}", self.schema_name))
87 .execute(&mut *conn)
88 .await?;
89 }
90
91 let migrations = self.load_migrations()?;
93
94 tracing::debug!(
95 "Loaded {} migrations for schema {}",
96 migrations.len(),
97 self.schema_name
98 );
99
100 self.ensure_migration_table(conn).await?;
102
103 let applied_versions = self.get_applied_versions(conn).await?;
105
106 tracing::debug!("Applied migrations: {:?}", applied_versions);
107
108 let tables_exist = self.check_tables_exist(conn).await.unwrap_or(false);
111
112 for migration in migrations {
114 let should_apply = if !applied_versions.contains(&migration.version) {
115 true } else if !tables_exist {
117 tracing::warn!(
119 "Migration {} is marked as applied but tables don't exist, re-applying",
120 migration.version
121 );
122 sqlx::query(&format!(
124 "DELETE FROM {}._duroxide_migrations WHERE version = $1",
125 self.schema_name
126 ))
127 .bind(migration.version)
128 .execute(&mut *conn)
129 .await?;
130 true
131 } else {
132 false };
134
135 if should_apply {
136 tracing::debug!(
137 "Applying migration {}: {}",
138 migration.version,
139 migration.name
140 );
141 self.apply_migration(conn, &migration).await?;
142 } else {
143 tracing::debug!(
144 "Skipping migration {}: {} (already applied)",
145 migration.version,
146 migration.name
147 );
148 }
149 }
150
151 Ok(())
152 }
153
154 fn load_migrations(&self) -> Result<Vec<Migration>> {
156 let mut migrations = Vec::new();
157
158 let mut files: Vec<_> = MIGRATIONS
160 .files()
161 .filter(|file| file.path().extension().and_then(|ext| ext.to_str()) == Some("sql"))
162 .collect();
163
164 files.sort_by_key(|f| f.path());
166
167 for file in files {
168 let file_name = file
169 .path()
170 .file_name()
171 .and_then(|n| n.to_str())
172 .ok_or_else(|| anyhow::anyhow!("Invalid filename in migrations"))?;
173
174 let sql = file
175 .contents_utf8()
176 .ok_or_else(|| anyhow::anyhow!("Migration file is not valid UTF-8: {file_name}"))?
177 .to_string();
178
179 let version = self.parse_version(file_name)?;
180 let name = file_name.to_string();
181
182 migrations.push(Migration { version, name, sql });
183 }
184
185 Ok(migrations)
186 }
187
188 fn parse_version(&self, filename: &str) -> Result<i64> {
190 let version_str = filename
191 .split('_')
192 .next()
193 .ok_or_else(|| anyhow::anyhow!("Invalid migration filename: {filename}"))?;
194
195 version_str
196 .parse::<i64>()
197 .map_err(|e| anyhow::anyhow!("Invalid migration version {version_str}: {e}"))
198 }
199
200 async fn ensure_migration_table(&self, conn: &mut sqlx::postgres::PgConnection) -> Result<()> {
202 sqlx::query(&format!(
204 r#"
205 CREATE TABLE IF NOT EXISTS {}._duroxide_migrations (
206 version BIGINT PRIMARY KEY,
207 name TEXT NOT NULL,
208 applied_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP
209 )
210 "#,
211 self.schema_name
212 ))
213 .execute(&mut *conn)
214 .await?;
215
216 Ok(())
217 }
218
219 async fn check_tables_exist(&self, conn: &mut sqlx::postgres::PgConnection) -> Result<bool> {
221 let exists: bool = sqlx::query_scalar(
223 "SELECT EXISTS(SELECT 1 FROM information_schema.tables WHERE table_schema = $1 AND table_name = 'instances')",
224 )
225 .bind(&self.schema_name)
226 .fetch_one(&mut *conn)
227 .await?;
228
229 Ok(exists)
230 }
231
232 async fn get_applied_versions(&self, conn: &mut sqlx::postgres::PgConnection) -> Result<Vec<i64>> {
234 let versions: Vec<i64> = sqlx::query_scalar(&format!(
235 "SELECT version FROM {}._duroxide_migrations ORDER BY version",
236 self.schema_name
237 ))
238 .fetch_all(&mut *conn)
239 .await?;
240
241 Ok(versions)
242 }
243
244 fn split_sql_statements(sql: &str) -> Vec<String> {
247 let mut statements = Vec::new();
248 let mut current_statement = String::new();
249 let chars: Vec<char> = sql.chars().collect();
250 let mut i = 0;
251 let mut in_dollar_quote = false;
252 let mut dollar_tag: Option<String> = None;
253
254 while i < chars.len() {
255 let ch = chars[i];
256
257 if !in_dollar_quote {
258 if ch == '$' {
260 let mut tag = String::new();
261 tag.push(ch);
262 i += 1;
263
264 while i < chars.len() {
266 let next_ch = chars[i];
267 if next_ch == '$' {
268 tag.push(next_ch);
269 dollar_tag = Some(tag.clone());
270 in_dollar_quote = true;
271 current_statement.push_str(&tag);
272 i += 1;
273 break;
274 } else if next_ch.is_alphanumeric() || next_ch == '_' {
275 tag.push(next_ch);
276 i += 1;
277 } else {
278 current_statement.push(ch);
280 break;
281 }
282 }
283 } else if ch == ';' {
284 current_statement.push(ch);
286 let trimmed = current_statement.trim().to_string();
287 if !trimmed.is_empty() {
288 statements.push(trimmed);
289 }
290 current_statement.clear();
291 i += 1;
292 } else {
293 current_statement.push(ch);
294 i += 1;
295 }
296 } else {
297 current_statement.push(ch);
299
300 if ch == '$' {
302 let tag = dollar_tag.as_ref().unwrap();
303 let mut matches = true;
304
305 for (j, tag_char) in tag.chars().enumerate() {
307 if j == 0 {
308 continue; }
310 if i + j >= chars.len() || chars[i + j] != tag_char {
311 matches = false;
312 break;
313 }
314 }
315
316 if matches {
317 for _ in 0..(tag.len() - 1) {
319 if i + 1 < chars.len() {
320 current_statement.push(chars[i + 1]);
321 i += 1;
322 }
323 }
324 in_dollar_quote = false;
325 dollar_tag = None;
326 }
327 }
328 i += 1;
329 }
330 }
331
332 let trimmed = current_statement.trim().to_string();
334 if !trimmed.is_empty() {
335 statements.push(trimmed);
336 }
337
338 statements
339 }
340
341 async fn apply_migration(&self, conn: &mut sqlx::postgres::PgConnection, migration: &Migration) -> Result<()> {
343 let mut tx = conn.begin().await?;
345
346 sqlx::query(&format!("SET LOCAL search_path TO {}", self.schema_name))
348 .execute(&mut *tx)
349 .await?;
350
351 let sql = migration.sql.trim();
353 let cleaned_sql: String = sql
354 .lines()
355 .map(|line| {
356 if let Some(idx) = line.find("--") {
358 let before = &line[..idx];
360 if before.matches('\'').count() % 2 == 0 {
361 line[..idx].trim()
363 } else {
364 line
365 }
366 } else {
367 line
368 }
369 })
370 .filter(|line| !line.is_empty())
371 .collect::<Vec<_>>()
372 .join("\n");
373
374 let statements = Self::split_sql_statements(&cleaned_sql);
376
377 tracing::debug!(
378 "Executing {} statements for migration {}",
379 statements.len(),
380 migration.version
381 );
382
383 for (idx, statement) in statements.iter().enumerate() {
384 if !statement.trim().is_empty() {
385 tracing::debug!(
386 "Executing statement {} of {}: {}...",
387 idx + 1,
388 statements.len(),
389 &statement.chars().take(50).collect::<String>()
390 );
391 sqlx::query(statement)
392 .execute(&mut *tx)
393 .await
394 .map_err(|e| {
395 anyhow::anyhow!(
396 "Failed to execute statement {} in migration {}: {}\nStatement: {}",
397 idx + 1,
398 migration.version,
399 e,
400 statement
401 )
402 })?;
403 }
404 }
405
406 sqlx::query(&format!(
408 "INSERT INTO {}._duroxide_migrations (version, name) VALUES ($1, $2)",
409 self.schema_name
410 ))
411 .bind(migration.version)
412 .bind(&migration.name)
413 .execute(&mut *tx)
414 .await?;
415
416 tx.commit().await?;
418
419 tracing::info!(
420 "Applied migration {}: {}",
421 migration.version,
422 migration.name
423 );
424
425 Ok(())
426 }
427}