duroxide_pg_opt/
migrations.rs1use anyhow::Result;
2use include_dir::{include_dir, Dir};
3use sqlx::PgPool;
4use std::sync::Arc;
5
6static MIGRATIONS: Dir = include_dir!("$CARGO_MANIFEST_DIR/migrations");
9
/// One embedded migration file, parsed and ready to be applied.
#[derive(Debug)]
struct Migration {
    /// Numeric version parsed from the part of the file name before the first `_`;
    /// it is the key recorded in the `_duroxide_migrations` table.
    version: i64,
    /// File name of the migration (including its extension), recorded next to the version.
    name: String,
    /// Full UTF-8 contents of the migration file.
    sql: String,
}
17
18pub struct MigrationRunner {
20 pool: Arc<PgPool>,
21 schema_name: String,
22}
23
24impl MigrationRunner {
25 pub fn new(pool: Arc<PgPool>, schema_name: String) -> Self {
27 Self { pool, schema_name }
28 }
29
30 pub async fn migrate(&self) -> Result<()> {
32 if self.schema_name != "public" {
34 sqlx::query(&format!("CREATE SCHEMA IF NOT EXISTS {}", self.schema_name))
35 .execute(&*self.pool)
36 .await?;
37 }
38
39 let migrations = self.load_migrations()?;
41
42 tracing::debug!(
43 "Loaded {} migrations for schema {}",
44 migrations.len(),
45 self.schema_name
46 );
47
48 self.ensure_migration_table().await?;
50
51 let applied_versions = self.get_applied_versions().await?;
53
54 tracing::debug!("Applied migrations: {:?}", applied_versions);
55
56 let tables_exist = self.check_tables_exist().await.unwrap_or(false);
59
60 for migration in migrations {
62 let should_apply = if !applied_versions.contains(&migration.version) {
63 true } else if !tables_exist {
65 tracing::warn!(
67 "Migration {} is marked as applied but tables don't exist, re-applying",
68 migration.version
69 );
70 sqlx::query(&format!(
72 "DELETE FROM {}._duroxide_migrations WHERE version = $1",
73 self.schema_name
74 ))
75 .bind(migration.version)
76 .execute(&*self.pool)
77 .await?;
78 true
79 } else {
80 false };
82
83 if should_apply {
84 tracing::debug!(
85 "Applying migration {}: {}",
86 migration.version,
87 migration.name
88 );
89 self.apply_migration(&migration).await?;
90 } else {
91 tracing::debug!(
92 "Skipping migration {}: {} (already applied)",
93 migration.version,
94 migration.name
95 );
96 }
97 }
98
99 Ok(())
100 }
101
102 fn load_migrations(&self) -> Result<Vec<Migration>> {
104 let mut migrations = Vec::new();
105
106 let mut files: Vec<_> = MIGRATIONS
108 .files()
109 .filter(|file| file.path().extension().and_then(|ext| ext.to_str()) == Some("sql"))
110 .collect();
111
112 files.sort_by_key(|f| f.path());
114
115 for file in files {
116 let file_name = file
117 .path()
118 .file_name()
119 .and_then(|n| n.to_str())
120 .ok_or_else(|| anyhow::anyhow!("Invalid filename in migrations"))?;
121
122 let sql = file
123 .contents_utf8()
124 .ok_or_else(|| anyhow::anyhow!("Migration file is not valid UTF-8: {file_name}"))?
125 .to_string();
126
127 let version = self.parse_version(file_name)?;
128 let name = file_name.to_string();
129
130 migrations.push(Migration { version, name, sql });
131 }
132
133 Ok(migrations)
134 }
135
136 fn parse_version(&self, filename: &str) -> Result<i64> {
138 let version_str = filename
139 .split('_')
140 .next()
141 .ok_or_else(|| anyhow::anyhow!("Invalid migration filename: {filename}"))?;
142
143 version_str
144 .parse::<i64>()
145 .map_err(|e| anyhow::anyhow!("Invalid migration version {version_str}: {e}"))
146 }
147
148 async fn ensure_migration_table(&self) -> Result<()> {
150 sqlx::query(&format!(
152 r#"
153 CREATE TABLE IF NOT EXISTS {}._duroxide_migrations (
154 version BIGINT PRIMARY KEY,
155 name TEXT NOT NULL,
156 applied_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP
157 )
158 "#,
159 self.schema_name
160 ))
161 .execute(&*self.pool)
162 .await?;
163
164 Ok(())
165 }
166
167 async fn check_tables_exist(&self) -> Result<bool> {
169 let exists: bool = sqlx::query_scalar(
171 "SELECT EXISTS(SELECT 1 FROM information_schema.tables WHERE table_schema = $1 AND table_name = 'instances')",
172 )
173 .bind(&self.schema_name)
174 .fetch_one(&*self.pool)
175 .await?;
176
177 Ok(exists)
178 }
179
180 async fn get_applied_versions(&self) -> Result<Vec<i64>> {
182 let versions: Vec<i64> = sqlx::query_scalar(&format!(
183 "SELECT version FROM {}._duroxide_migrations ORDER BY version",
184 self.schema_name
185 ))
186 .fetch_all(&*self.pool)
187 .await?;
188
189 Ok(versions)
190 }
191
/// Splits a SQL script into individual statements at `;`, keeping dollar-quoted
/// bodies (`$$ … $$`, `$tag$ … $tag$`) in one piece.
///
/// Each returned statement is trimmed and keeps its terminating `;`. Text after
/// the last `;` is returned as a final statement when it is not blank. An
/// unterminated dollar-quote extends to the end of the input.
///
/// A `$` that does not open a dollar-quote (for example `$1`) is copied through
/// unchanged together with the characters that follow it.
///
/// Limitations: single-quoted string literals, quoted identifiers and comments
/// are not recognised, so a `;` inside any of them still ends a statement.
fn split_sql_statements(sql: &str) -> Vec<String> {
    /// Characters accepted between the two `$` of a dollar-quote delimiter.
    fn is_tag_char(c: char) -> bool {
        c.is_alphanumeric() || c == '_'
    }

    let chars: Vec<char> = sql.chars().collect();
    let mut statements = Vec::new();
    let mut current = String::new();
    // Full opening delimiter (both `$` included) while inside a dollar-quoted body.
    let mut open_tag: Option<Vec<char>> = None;
    let mut i = 0;

    while i < chars.len() {
        if let Some(tag) = open_tag.as_deref() {
            // Inside a body everything is copied verbatim until the same delimiter
            // appears again. The comparison is per character, so non-ASCII tags
            // are matched and skipped correctly.
            if chars[i..].starts_with(tag) {
                current.extend(tag);
                i += tag.len();
                open_tag = None;
            } else {
                current.push(chars[i]);
                i += 1;
            }
            continue;
        }

        match chars[i] {
            '$' => {
                // Index of the first character after `$` that cannot belong to a tag.
                let tag_end = chars[i + 1..]
                    .iter()
                    .position(|&c| !is_tag_char(c))
                    .map_or(chars.len(), |offset| i + 1 + offset);

                if chars.get(tag_end) == Some(&'$') {
                    // `$tag$` (or `$$`): a dollar-quoted body starts here.
                    let tag = chars[i..=tag_end].to_vec();
                    current.extend(&tag);
                    open_tag = Some(tag);
                    i = tag_end + 1;
                } else {
                    // Not a delimiter: keep the `$` and every character scanned so
                    // far, then resume normal parsing at the character that ended
                    // the scan.
                    current.extend(&chars[i..tag_end]);
                    i = tag_end;
                }
            }
            ';' => {
                current.push(';');
                statements.push(current.trim().to_string());
                current.clear();
                i += 1;
            }
            other => {
                current.push(other);
                i += 1;
            }
        }
    }

    // Trailing text without a terminating `;`.
    let remainder = current.trim();
    if !remainder.is_empty() {
        statements.push(remainder.to_string());
    }

    statements
}
288
289 async fn apply_migration(&self, migration: &Migration) -> Result<()> {
291 let mut tx = self.pool.begin().await?;
293
294 sqlx::query(&format!("SET LOCAL search_path TO {}", self.schema_name))
296 .execute(&mut *tx)
297 .await?;
298
299 let sql = migration.sql.trim();
301 let cleaned_sql: String = sql
302 .lines()
303 .map(|line| {
304 if let Some(idx) = line.find("--") {
306 let before = &line[..idx];
308 if before.matches('\'').count() % 2 == 0 {
309 line[..idx].trim()
311 } else {
312 line
313 }
314 } else {
315 line
316 }
317 })
318 .filter(|line| !line.is_empty())
319 .collect::<Vec<_>>()
320 .join("\n");
321
322 let statements = Self::split_sql_statements(&cleaned_sql);
324
325 tracing::debug!(
326 "Executing {} statements for migration {}",
327 statements.len(),
328 migration.version
329 );
330
331 for (idx, statement) in statements.iter().enumerate() {
332 if !statement.trim().is_empty() {
333 tracing::debug!(
334 "Executing statement {} of {}: {}...",
335 idx + 1,
336 statements.len(),
337 &statement.chars().take(50).collect::<String>()
338 );
339 sqlx::query(statement)
340 .execute(&mut *tx)
341 .await
342 .map_err(|e| {
343 anyhow::anyhow!(
344 "Failed to execute statement {} in migration {}: {}\nStatement: {}",
345 idx + 1,
346 migration.version,
347 e,
348 statement
349 )
350 })?;
351 }
352 }
353
354 sqlx::query(&format!(
356 "INSERT INTO {}._duroxide_migrations (version, name) VALUES ($1, $2)",
357 self.schema_name
358 ))
359 .bind(migration.version)
360 .bind(&migration.name)
361 .execute(&mut *tx)
362 .await?;
363
364 tx.commit().await?;
366
367 tracing::info!(
368 "Applied migration {}: {}",
369 migration.version,
370 migration.name
371 );
372
373 Ok(())
374 }
375}