duroxide_pg/
migrations.rs1use anyhow::Result;
2use include_dir::{include_dir, Dir};
3use sqlx::Connection;
4use sqlx::PgPool;
5use std::sync::Arc;
6
7static MIGRATIONS: Dir = include_dir!("$CARGO_MANIFEST_DIR/migrations");
8
/// One SQL migration script embedded at compile time from the crate's
/// `migrations/` directory.
#[derive(Debug)]
struct Migration {
    /// Version number parsed from the part of the file name before the first `_`.
    version: i64,
    /// Full file name of the script; stored in `_duroxide_migrations.name`.
    name: String,
    /// Complete UTF-8 text of the script.
    sql: String,
}
16
/// Applies the embedded SQL migrations to a single PostgreSQL schema and
/// records applied versions in `<schema>._duroxide_migrations`.
///
/// NOTE: `schema_name` is interpolated verbatim (unquoted) into SQL text, so it
/// must come from trusted configuration and be a valid PostgreSQL identifier.
pub struct MigrationRunner {
    /// Pool from which one connection is acquired for an entire migration run.
    pool: Arc<PgPool>,
    /// Target schema. The runner does not issue `CREATE SCHEMA` for `"public"`.
    schema_name: String,
}
22
23impl MigrationRunner {
24 pub fn new(pool: Arc<PgPool>, schema_name: String) -> Self {
26 Self { pool, schema_name }
27 }
28
29 fn advisory_lock_key(&self) -> i64 {
30 const OFFSET: u64 = 0xcbf29ce484222325;
33 const PRIME: u64 = 0x100000001b3;
34
35 let mut hash = OFFSET;
36 for b in b"duroxide_pg:migrations:" {
37 hash ^= *b as u64;
38 hash = hash.wrapping_mul(PRIME);
39 }
40 for b in self.schema_name.as_bytes() {
41 hash ^= *b as u64;
42 hash = hash.wrapping_mul(PRIME);
43 }
44
45 hash as i64
46 }
47
48 async fn lock_for_migrations(&self, conn: &mut sqlx::postgres::PgConnection) -> Result<()> {
49 let key = self.advisory_lock_key();
50 sqlx::query("SELECT pg_advisory_lock($1)")
53 .bind(key)
54 .execute(&mut *conn)
55 .await?;
56 Ok(())
57 }
58
59 async fn unlock_for_migrations(&self, conn: &mut sqlx::postgres::PgConnection) {
60 let key = self.advisory_lock_key();
61 let _ = sqlx::query("SELECT pg_advisory_unlock($1)")
63 .bind(key)
64 .execute(&mut *conn)
65 .await;
66 }
67
68 pub async fn migrate(&self) -> Result<()> {
70 let mut conn = self.pool.acquire().await?;
71 let conn = &mut *conn;
72 self.lock_for_migrations(conn).await?;
73
74 let result = self.migrate_inner(conn).await;
75 self.unlock_for_migrations(conn).await;
76
77 result
78 }
79
80 async fn migrate_inner(&self, conn: &mut sqlx::postgres::PgConnection) -> Result<()> {
81 if self.schema_name != "public" {
83 sqlx::query(&format!("CREATE SCHEMA IF NOT EXISTS {}", self.schema_name))
84 .execute(&mut *conn)
85 .await?;
86 }
87
88 let migrations = self.load_migrations()?;
90
91 tracing::debug!(
92 "Loaded {} migrations for schema {}",
93 migrations.len(),
94 self.schema_name
95 );
96
97 self.ensure_migration_table(conn).await?;
99
100 let applied_versions = self.get_applied_versions(conn).await?;
102
103 tracing::debug!("Applied migrations: {:?}", applied_versions);
104
105 let tables_exist = self.check_tables_exist(conn).await.unwrap_or(false);
108
109 for migration in migrations {
111 let should_apply = if !applied_versions.contains(&migration.version) {
112 true } else if !tables_exist {
114 tracing::warn!(
116 "Migration {} is marked as applied but tables don't exist, re-applying",
117 migration.version
118 );
119 sqlx::query(&format!(
121 "DELETE FROM {}._duroxide_migrations WHERE version = $1",
122 self.schema_name
123 ))
124 .bind(migration.version)
125 .execute(&mut *conn)
126 .await?;
127 true
128 } else {
129 false };
131
132 if should_apply {
133 tracing::debug!(
134 "Applying migration {}: {}",
135 migration.version,
136 migration.name
137 );
138 self.apply_migration(conn, &migration).await?;
139 } else {
140 tracing::debug!(
141 "Skipping migration {}: {} (already applied)",
142 migration.version,
143 migration.name
144 );
145 }
146 }
147
148 Ok(())
149 }
150
151 fn load_migrations(&self) -> Result<Vec<Migration>> {
153 let mut migrations = Vec::new();
154
155 let mut files: Vec<_> = MIGRATIONS
157 .files()
158 .filter(|file| file.path().extension().and_then(|ext| ext.to_str()) == Some("sql"))
159 .collect();
160
161 files.sort_by_key(|f| f.path());
163
164 for file in files {
165 let file_name = file
166 .path()
167 .file_name()
168 .and_then(|n| n.to_str())
169 .ok_or_else(|| anyhow::anyhow!("Invalid filename in migrations"))?;
170
171 let sql = file
172 .contents_utf8()
173 .ok_or_else(|| anyhow::anyhow!("Migration file is not valid UTF-8: {file_name}"))?
174 .to_string();
175
176 let version = self.parse_version(file_name)?;
177 let name = file_name.to_string();
178
179 migrations.push(Migration { version, name, sql });
180 }
181
182 Ok(migrations)
183 }
184
185 fn parse_version(&self, filename: &str) -> Result<i64> {
187 let version_str = filename
188 .split('_')
189 .next()
190 .ok_or_else(|| anyhow::anyhow!("Invalid migration filename: {filename}"))?;
191
192 version_str
193 .parse::<i64>()
194 .map_err(|e| anyhow::anyhow!("Invalid migration version {version_str}: {e}"))
195 }
196
197 async fn ensure_migration_table(&self, conn: &mut sqlx::postgres::PgConnection) -> Result<()> {
199 sqlx::query(&format!(
201 r#"
202 CREATE TABLE IF NOT EXISTS {}._duroxide_migrations (
203 version BIGINT PRIMARY KEY,
204 name TEXT NOT NULL,
205 applied_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP
206 )
207 "#,
208 self.schema_name
209 ))
210 .execute(&mut *conn)
211 .await?;
212
213 Ok(())
214 }
215
216 async fn check_tables_exist(&self, conn: &mut sqlx::postgres::PgConnection) -> Result<bool> {
218 let exists: bool = sqlx::query_scalar(
220 "SELECT EXISTS(SELECT 1 FROM information_schema.tables WHERE table_schema = $1 AND table_name = 'instances')",
221 )
222 .bind(&self.schema_name)
223 .fetch_one(&mut *conn)
224 .await?;
225
226 Ok(exists)
227 }
228
229 async fn get_applied_versions(
231 &self,
232 conn: &mut sqlx::postgres::PgConnection,
233 ) -> Result<Vec<i64>> {
234 let versions: Vec<i64> = sqlx::query_scalar(&format!(
235 "SELECT version FROM {}._duroxide_migrations ORDER BY version",
236 self.schema_name
237 ))
238 .fetch_all(&mut *conn)
239 .await?;
240
241 Ok(versions)
242 }
243
244 fn split_sql_statements(sql: &str) -> Vec<String> {
247 let mut statements = Vec::new();
248 let mut current_statement = String::new();
249 let chars: Vec<char> = sql.chars().collect();
250 let mut i = 0;
251 let mut in_dollar_quote = false;
252 let mut dollar_tag: Option<String> = None;
253
254 while i < chars.len() {
255 let ch = chars[i];
256
257 if !in_dollar_quote {
258 if ch == '$' {
260 let mut tag = String::new();
261 tag.push(ch);
262 i += 1;
263
264 while i < chars.len() {
266 let next_ch = chars[i];
267 if next_ch == '$' {
268 tag.push(next_ch);
269 dollar_tag = Some(tag.clone());
270 in_dollar_quote = true;
271 current_statement.push_str(&tag);
272 i += 1;
273 break;
274 } else if next_ch.is_alphanumeric() || next_ch == '_' {
275 tag.push(next_ch);
276 i += 1;
277 } else {
278 current_statement.push(ch);
280 break;
281 }
282 }
283 } else if ch == ';' {
284 current_statement.push(ch);
286 let trimmed = current_statement.trim().to_string();
287 if !trimmed.is_empty() {
288 statements.push(trimmed);
289 }
290 current_statement.clear();
291 i += 1;
292 } else {
293 current_statement.push(ch);
294 i += 1;
295 }
296 } else {
297 current_statement.push(ch);
299
300 if ch == '$' {
302 let tag = dollar_tag.as_ref().unwrap();
303 let mut matches = true;
304
305 for (j, tag_char) in tag.chars().enumerate() {
307 if j == 0 {
308 continue; }
310 if i + j >= chars.len() || chars[i + j] != tag_char {
311 matches = false;
312 break;
313 }
314 }
315
316 if matches {
317 for _ in 0..(tag.len() - 1) {
319 if i + 1 < chars.len() {
320 current_statement.push(chars[i + 1]);
321 i += 1;
322 }
323 }
324 in_dollar_quote = false;
325 dollar_tag = None;
326 }
327 }
328 i += 1;
329 }
330 }
331
332 let trimmed = current_statement.trim().to_string();
334 if !trimmed.is_empty() {
335 statements.push(trimmed);
336 }
337
338 statements
339 }
340
341 async fn apply_migration(
343 &self,
344 conn: &mut sqlx::postgres::PgConnection,
345 migration: &Migration,
346 ) -> Result<()> {
347 let mut tx = conn.begin().await?;
349
350 sqlx::query(&format!("SET LOCAL search_path TO {}", self.schema_name))
352 .execute(&mut *tx)
353 .await?;
354
355 let sql = migration.sql.trim();
357 let cleaned_sql: String = sql
358 .lines()
359 .map(|line| {
360 if let Some(idx) = line.find("--") {
362 let before = &line[..idx];
364 if before.matches('\'').count() % 2 == 0 {
365 line[..idx].trim()
367 } else {
368 line
369 }
370 } else {
371 line
372 }
373 })
374 .filter(|line| !line.is_empty())
375 .collect::<Vec<_>>()
376 .join("\n");
377
378 let statements = Self::split_sql_statements(&cleaned_sql);
380
381 tracing::debug!(
382 "Executing {} statements for migration {}",
383 statements.len(),
384 migration.version
385 );
386
387 for (idx, statement) in statements.iter().enumerate() {
388 if !statement.trim().is_empty() {
389 tracing::debug!(
390 "Executing statement {} of {}: {}...",
391 idx + 1,
392 statements.len(),
393 &statement.chars().take(50).collect::<String>()
394 );
395 sqlx::query(statement)
396 .execute(&mut *tx)
397 .await
398 .map_err(|e| {
399 anyhow::anyhow!(
400 "Failed to execute statement {} in migration {}: {}\nStatement: {}",
401 idx + 1,
402 migration.version,
403 e,
404 statement
405 )
406 })?;
407 }
408 }
409
410 sqlx::query(&format!(
412 "INSERT INTO {}._duroxide_migrations (version, name) VALUES ($1, $2)",
413 self.schema_name
414 ))
415 .bind(migration.version)
416 .bind(&migration.name)
417 .execute(&mut *tx)
418 .await?;
419
420 tx.commit().await?;
422
423 tracing::info!(
424 "Applied migration {}: {}",
425 migration.version,
426 migration.name
427 );
428
429 Ok(())
430 }
431}