duroxide_pg/
migrations.rs1use anyhow::Result;
2use include_dir::{include_dir, Dir};
3use sqlx::PgPool;
4use std::sync::Arc;
5
6static MIGRATIONS: Dir = include_dir!("$CARGO_MANIFEST_DIR/migrations");
7
/// A single embedded migration script, as loaded from the `migrations/` directory.
#[derive(Debug)]
struct Migration {
    /// Numeric prefix of the file name (the text before the first `_`).
    version: i64,
    /// Full file name; this is what gets recorded in `_duroxide_migrations.name`.
    name: String,
    /// The complete SQL text of the file.
    sql: String,
}
15
16pub struct MigrationRunner {
18 pool: Arc<PgPool>,
19 schema_name: String,
20}
21
22impl MigrationRunner {
23 pub fn new(pool: Arc<PgPool>, schema_name: String) -> Self {
25 Self { pool, schema_name }
26 }
27
28 pub async fn migrate(&self) -> Result<()> {
30 if self.schema_name != "public" {
32 sqlx::query(&format!("CREATE SCHEMA IF NOT EXISTS {}", self.schema_name))
33 .execute(&*self.pool)
34 .await?;
35 }
36
37 let migrations = self.load_migrations()?;
39
40 tracing::debug!(
41 "Loaded {} migrations for schema {}",
42 migrations.len(),
43 self.schema_name
44 );
45
46 self.ensure_migration_table().await?;
48
49 let applied_versions = self.get_applied_versions().await?;
51
52 tracing::debug!("Applied migrations: {:?}", applied_versions);
53
54 let tables_exist = self.check_tables_exist().await.unwrap_or(false);
57
58 for migration in migrations {
60 let should_apply = if !applied_versions.contains(&migration.version) {
61 true } else if !tables_exist {
63 tracing::warn!(
65 "Migration {} is marked as applied but tables don't exist, re-applying",
66 migration.version
67 );
68 sqlx::query(&format!(
70 "DELETE FROM {}._duroxide_migrations WHERE version = $1",
71 self.schema_name
72 ))
73 .bind(migration.version)
74 .execute(&*self.pool)
75 .await?;
76 true
77 } else {
78 false };
80
81 if should_apply {
82 tracing::debug!(
83 "Applying migration {}: {}",
84 migration.version,
85 migration.name
86 );
87 self.apply_migration(&migration).await?;
88 } else {
89 tracing::debug!(
90 "Skipping migration {}: {} (already applied)",
91 migration.version,
92 migration.name
93 );
94 }
95 }
96
97 Ok(())
98 }
99
100 fn load_migrations(&self) -> Result<Vec<Migration>> {
102 let mut migrations = Vec::new();
103
104 let mut files: Vec<_> = MIGRATIONS
106 .files()
107 .filter(|file| file.path().extension().and_then(|ext| ext.to_str()) == Some("sql"))
108 .collect();
109
110 files.sort_by_key(|f| f.path());
112
113 for file in files {
114 let file_name = file
115 .path()
116 .file_name()
117 .and_then(|n| n.to_str())
118 .ok_or_else(|| anyhow::anyhow!("Invalid filename in migrations"))?;
119
120 let sql = file
121 .contents_utf8()
122 .ok_or_else(|| anyhow::anyhow!("Migration file is not valid UTF-8: {file_name}"))?
123 .to_string();
124
125 let version = self.parse_version(file_name)?;
126 let name = file_name.to_string();
127
128 migrations.push(Migration { version, name, sql });
129 }
130
131 Ok(migrations)
132 }
133
134 fn parse_version(&self, filename: &str) -> Result<i64> {
136 let version_str = filename
137 .split('_')
138 .next()
139 .ok_or_else(|| anyhow::anyhow!("Invalid migration filename: {filename}"))?;
140
141 version_str
142 .parse::<i64>()
143 .map_err(|e| anyhow::anyhow!("Invalid migration version {version_str}: {e}"))
144 }
145
146 async fn ensure_migration_table(&self) -> Result<()> {
148 sqlx::query(&format!(
150 r#"
151 CREATE TABLE IF NOT EXISTS {}._duroxide_migrations (
152 version BIGINT PRIMARY KEY,
153 name TEXT NOT NULL,
154 applied_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP
155 )
156 "#,
157 self.schema_name
158 ))
159 .execute(&*self.pool)
160 .await?;
161
162 Ok(())
163 }
164
165 async fn check_tables_exist(&self) -> Result<bool> {
167 let exists: bool = sqlx::query_scalar(
169 "SELECT EXISTS(SELECT 1 FROM information_schema.tables WHERE table_schema = $1 AND table_name = 'instances')",
170 )
171 .bind(&self.schema_name)
172 .fetch_one(&*self.pool)
173 .await?;
174
175 Ok(exists)
176 }
177
178 async fn get_applied_versions(&self) -> Result<Vec<i64>> {
180 let versions: Vec<i64> = sqlx::query_scalar(&format!(
181 "SELECT version FROM {}._duroxide_migrations ORDER BY version",
182 self.schema_name
183 ))
184 .fetch_all(&*self.pool)
185 .await?;
186
187 Ok(versions)
188 }
189
190 fn split_sql_statements(sql: &str) -> Vec<String> {
193 let mut statements = Vec::new();
194 let mut current_statement = String::new();
195 let chars: Vec<char> = sql.chars().collect();
196 let mut i = 0;
197 let mut in_dollar_quote = false;
198 let mut dollar_tag: Option<String> = None;
199
200 while i < chars.len() {
201 let ch = chars[i];
202
203 if !in_dollar_quote {
204 if ch == '$' {
206 let mut tag = String::new();
207 tag.push(ch);
208 i += 1;
209
210 while i < chars.len() {
212 let next_ch = chars[i];
213 if next_ch == '$' {
214 tag.push(next_ch);
215 dollar_tag = Some(tag.clone());
216 in_dollar_quote = true;
217 current_statement.push_str(&tag);
218 i += 1;
219 break;
220 } else if next_ch.is_alphanumeric() || next_ch == '_' {
221 tag.push(next_ch);
222 i += 1;
223 } else {
224 current_statement.push(ch);
226 break;
227 }
228 }
229 } else if ch == ';' {
230 current_statement.push(ch);
232 let trimmed = current_statement.trim().to_string();
233 if !trimmed.is_empty() {
234 statements.push(trimmed);
235 }
236 current_statement.clear();
237 i += 1;
238 } else {
239 current_statement.push(ch);
240 i += 1;
241 }
242 } else {
243 current_statement.push(ch);
245
246 if ch == '$' {
248 let tag = dollar_tag.as_ref().unwrap();
249 let mut matches = true;
250
251 for (j, tag_char) in tag.chars().enumerate() {
253 if j == 0 {
254 continue; }
256 if i + j >= chars.len() || chars[i + j] != tag_char {
257 matches = false;
258 break;
259 }
260 }
261
262 if matches {
263 for _ in 0..(tag.len() - 1) {
265 if i + 1 < chars.len() {
266 current_statement.push(chars[i + 1]);
267 i += 1;
268 }
269 }
270 in_dollar_quote = false;
271 dollar_tag = None;
272 }
273 }
274 i += 1;
275 }
276 }
277
278 let trimmed = current_statement.trim().to_string();
280 if !trimmed.is_empty() {
281 statements.push(trimmed);
282 }
283
284 statements
285 }
286
287 async fn apply_migration(&self, migration: &Migration) -> Result<()> {
289 let mut tx = self.pool.begin().await?;
291
292 sqlx::query(&format!("SET LOCAL search_path TO {}", self.schema_name))
294 .execute(&mut *tx)
295 .await?;
296
297 let sql = migration.sql.trim();
299 let cleaned_sql: String = sql
300 .lines()
301 .map(|line| {
302 if let Some(idx) = line.find("--") {
304 let before = &line[..idx];
306 if before.matches('\'').count() % 2 == 0 {
307 line[..idx].trim()
309 } else {
310 line
311 }
312 } else {
313 line
314 }
315 })
316 .filter(|line| !line.is_empty())
317 .collect::<Vec<_>>()
318 .join("\n");
319
320 let statements = Self::split_sql_statements(&cleaned_sql);
322
323 tracing::debug!(
324 "Executing {} statements for migration {}",
325 statements.len(),
326 migration.version
327 );
328
329 for (idx, statement) in statements.iter().enumerate() {
330 if !statement.trim().is_empty() {
331 tracing::debug!(
332 "Executing statement {} of {}: {}...",
333 idx + 1,
334 statements.len(),
335 &statement.chars().take(50).collect::<String>()
336 );
337 sqlx::query(statement)
338 .execute(&mut *tx)
339 .await
340 .map_err(|e| {
341 anyhow::anyhow!(
342 "Failed to execute statement {} in migration {}: {}\nStatement: {}",
343 idx + 1,
344 migration.version,
345 e,
346 statement
347 )
348 })?;
349 }
350 }
351
352 sqlx::query(&format!(
354 "INSERT INTO {}._duroxide_migrations (version, name) VALUES ($1, $2)",
355 self.schema_name
356 ))
357 .bind(migration.version)
358 .bind(&migration.name)
359 .execute(&mut *tx)
360 .await?;
361
362 tx.commit().await?;
364
365 tracing::info!(
366 "Applied migration {}: {}",
367 migration.version,
368 migration.name
369 );
370
371 Ok(())
372 }
373}