// duroxide_pg/migrations.rs
use std::collections::HashSet;
use std::sync::Arc;

use anyhow::Result;
use include_dir::{include_dir, Dir};
use sqlx::PgPool;

6static MIGRATIONS: Dir = include_dir!("$CARGO_MANIFEST_DIR/migrations");
7
/// One embedded migration script, as loaded from the migrations directory.
#[derive(Debug, Clone, PartialEq, Eq)]
struct Migration {
    /// Numeric version parsed from the file-name prefix (`<version>_<name>.sql`).
    version: i64,
    /// Full file name; recorded in `_duroxide_migrations.name` once applied.
    name: String,
    /// Raw SQL text of the file.
    sql: String,
}

16pub struct MigrationRunner {
18 pool: Arc<PgPool>,
19 schema_name: String,
20}
21
22impl MigrationRunner {
23 pub fn new(pool: Arc<PgPool>, schema_name: String) -> Self {
25 Self { pool, schema_name }
26 }
27
28 pub async fn migrate(&self) -> Result<()> {
30 if self.schema_name != "public" {
32 sqlx::query(&format!("CREATE SCHEMA IF NOT EXISTS {}", self.schema_name))
33 .execute(&*self.pool)
34 .await?;
35 }
36
37 let migrations = self.load_migrations()?;
39
40 tracing::debug!(
41 "Loaded {} migrations for schema {}",
42 migrations.len(),
43 self.schema_name
44 );
45
46 self.ensure_migration_table().await?;
48
49 let applied_versions = self.get_applied_versions().await?;
51
52 tracing::debug!("Applied migrations: {:?}", applied_versions);
53
54 let tables_exist = self.check_tables_exist().await.unwrap_or(false);
57
58 for migration in migrations {
60 let should_apply = if !applied_versions.contains(&migration.version) {
61 true } else if !tables_exist {
63 tracing::warn!(
65 "Migration {} is marked as applied but tables don't exist, re-applying",
66 migration.version
67 );
68 sqlx::query(&format!(
70 "DELETE FROM {}._duroxide_migrations WHERE version = $1",
71 self.schema_name
72 ))
73 .bind(migration.version)
74 .execute(&*self.pool)
75 .await?;
76 true
77 } else {
78 false };
80
81 if should_apply {
82 tracing::debug!(
83 "Applying migration {}: {}",
84 migration.version,
85 migration.name
86 );
87 self.apply_migration(&migration).await?;
88 } else {
89 tracing::debug!(
90 "Skipping migration {}: {} (already applied)",
91 migration.version,
92 migration.name
93 );
94 }
95 }
96
97 Ok(())
98 }
99
100 fn load_migrations(&self) -> Result<Vec<Migration>> {
102 let mut migrations = Vec::new();
103
104 let mut files: Vec<_> = MIGRATIONS
106 .files()
107 .filter(|file| {
108 file.path()
109 .extension()
110 .and_then(|ext| ext.to_str())
111 == Some("sql")
112 })
113 .collect();
114
115 files.sort_by_key(|f| f.path());
117
118 for file in files {
119 let file_name = file
120 .path()
121 .file_name()
122 .and_then(|n| n.to_str())
123 .ok_or_else(|| anyhow::anyhow!("Invalid filename in migrations"))?;
124
125 let sql = file
126 .contents_utf8()
127 .ok_or_else(|| anyhow::anyhow!("Migration file is not valid UTF-8: {}", file_name))?
128 .to_string();
129
130 let version = self.parse_version(file_name)?;
131 let name = file_name.to_string();
132
133 migrations.push(Migration { version, name, sql });
134 }
135
136 Ok(migrations)
137 }
138
139 fn parse_version(&self, filename: &str) -> Result<i64> {
141 let version_str = filename
142 .split('_')
143 .next()
144 .ok_or_else(|| anyhow::anyhow!("Invalid migration filename: {}", filename))?;
145
146 version_str
147 .parse::<i64>()
148 .map_err(|e| anyhow::anyhow!("Invalid migration version {}: {}", version_str, e))
149 }
150
151 async fn ensure_migration_table(&self) -> Result<()> {
153 sqlx::query(&format!(
155 r#"
156 CREATE TABLE IF NOT EXISTS {}._duroxide_migrations (
157 version BIGINT PRIMARY KEY,
158 name TEXT NOT NULL,
159 applied_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP
160 )
161 "#,
162 self.schema_name
163 ))
164 .execute(&*self.pool)
165 .await?;
166
167 Ok(())
168 }
169
170 async fn check_tables_exist(&self) -> Result<bool> {
172 let exists: bool = sqlx::query_scalar(&format!(
174 "SELECT EXISTS(SELECT 1 FROM information_schema.tables WHERE table_schema = $1 AND table_name = 'instances')",
175 ))
176 .bind(&self.schema_name)
177 .fetch_one(&*self.pool)
178 .await?;
179
180 Ok(exists)
181 }
182
183 async fn get_applied_versions(&self) -> Result<Vec<i64>> {
185 let versions: Vec<i64> = sqlx::query_scalar(&format!(
186 "SELECT version FROM {}._duroxide_migrations ORDER BY version",
187 self.schema_name
188 ))
189 .fetch_all(&*self.pool)
190 .await?;
191
192 Ok(versions)
193 }
194
195 fn split_sql_statements(sql: &str) -> Vec<String> {
198 let mut statements = Vec::new();
199 let mut current_statement = String::new();
200 let chars: Vec<char> = sql.chars().collect();
201 let mut i = 0;
202 let mut in_dollar_quote = false;
203 let mut dollar_tag: Option<String> = None;
204
205 while i < chars.len() {
206 let ch = chars[i];
207
208 if !in_dollar_quote {
209 if ch == '$' {
211 let mut tag = String::new();
212 tag.push(ch);
213 i += 1;
214
215 while i < chars.len() {
217 let next_ch = chars[i];
218 if next_ch == '$' {
219 tag.push(next_ch);
220 dollar_tag = Some(tag.clone());
221 in_dollar_quote = true;
222 current_statement.push_str(&tag);
223 i += 1;
224 break;
225 } else if next_ch.is_alphanumeric() || next_ch == '_' {
226 tag.push(next_ch);
227 i += 1;
228 } else {
229 current_statement.push(ch);
231 break;
232 }
233 }
234 } else if ch == ';' {
235 current_statement.push(ch);
237 let trimmed = current_statement.trim().to_string();
238 if !trimmed.is_empty() {
239 statements.push(trimmed);
240 }
241 current_statement.clear();
242 i += 1;
243 } else {
244 current_statement.push(ch);
245 i += 1;
246 }
247 } else {
248 current_statement.push(ch);
250
251 if ch == '$' {
253 let tag = dollar_tag.as_ref().unwrap();
254 let mut matches = true;
255
256 for (j, tag_char) in tag.chars().enumerate() {
258 if j == 0 {
259 continue; }
261 if i + j >= chars.len() || chars[i + j] != tag_char {
262 matches = false;
263 break;
264 }
265 }
266
267 if matches {
268 for _ in 0..(tag.len() - 1) {
270 if i + 1 < chars.len() {
271 current_statement.push(chars[i + 1]);
272 i += 1;
273 }
274 }
275 in_dollar_quote = false;
276 dollar_tag = None;
277 }
278 }
279 i += 1;
280 }
281 }
282
283 let trimmed = current_statement.trim().to_string();
285 if !trimmed.is_empty() {
286 statements.push(trimmed);
287 }
288
289 statements
290 }
291
292 async fn apply_migration(&self, migration: &Migration) -> Result<()> {
294 let mut tx = self.pool.begin().await?;
296
297 sqlx::query(&format!("SET LOCAL search_path TO {}", self.schema_name))
299 .execute(&mut *tx)
300 .await?;
301
302 let sql = migration.sql.trim();
304 let cleaned_sql: String = sql
305 .lines()
306 .map(|line| {
307 if let Some(idx) = line.find("--") {
309 let before = &line[..idx];
311 if before.matches('\'').count() % 2 == 0 {
312 line[..idx].trim()
314 } else {
315 line
316 }
317 } else {
318 line
319 }
320 })
321 .filter(|line| !line.is_empty())
322 .collect::<Vec<_>>()
323 .join("\n");
324
325 let statements = Self::split_sql_statements(&cleaned_sql);
327
328 tracing::debug!(
329 "Executing {} statements for migration {}",
330 statements.len(),
331 migration.version
332 );
333
334 for (idx, statement) in statements.iter().enumerate() {
335 if !statement.trim().is_empty() {
336 tracing::debug!(
337 "Executing statement {} of {}: {}...",
338 idx + 1,
339 statements.len(),
340 &statement.chars().take(50).collect::<String>()
341 );
342 sqlx::query(statement)
343 .execute(&mut *tx)
344 .await
345 .map_err(|e| {
346 anyhow::anyhow!(
347 "Failed to execute statement {} in migration {}: {}\nStatement: {}",
348 idx + 1,
349 migration.version,
350 e,
351 statement
352 )
353 })?;
354 }
355 }
356
357 sqlx::query(&format!(
359 "INSERT INTO {}._duroxide_migrations (version, name) VALUES ($1, $2)",
360 self.schema_name
361 ))
362 .bind(migration.version)
363 .bind(&migration.name)
364 .execute(&mut *tx)
365 .await?;
366
367 tx.commit().await?;
369
370 tracing::info!(
371 "Applied migration {}: {}",
372 migration.version,
373 migration.name
374 );
375
376 Ok(())
377 }
378}