1#![allow(clippy::unwrap_used, clippy::indexing_slicing)]
12
13use sqlx::PgPool;
14use std::path::Path;
15use tracing::{debug, info};
16
17use crate::error::{ForgeError, Result};
18
19pub struct TestDatabase {
34 pool: PgPool,
35 url: String,
36}
37
38impl TestDatabase {
39 pub async fn from_url(url: &str) -> Result<Self> {
43 let pool = sqlx::postgres::PgPoolOptions::new()
44 .max_connections(10)
45 .connect(url)
46 .await
47 .map_err(ForgeError::Sql)?;
48
49 Ok(Self {
50 pool,
51 url: url.to_string(),
52 })
53 }
54
55 pub async fn from_env() -> Result<Self> {
60 let url = std::env::var("TEST_DATABASE_URL").map_err(|_| {
61 ForgeError::Database(
62 "TEST_DATABASE_URL not set. Set it explicitly for database tests.".to_string(),
63 )
64 })?;
65 Self::from_url(&url).await
66 }
67
68 pub fn pool(&self) -> &PgPool {
70 &self.pool
71 }
72
73 pub fn url(&self) -> &str {
75 &self.url
76 }
77
78 pub async fn execute(&self, sql: &str) -> Result<()> {
80 sqlx::query(sql)
81 .execute(&self.pool)
82 .await
83 .map_err(ForgeError::Sql)?;
84 Ok(())
85 }
86
87 pub async fn isolated(&self, test_name: &str) -> Result<IsolatedTestDb> {
92 let base_url = self.url.clone();
93 let db_name = format!(
95 "forge_test_{}_{}",
96 sanitize_db_name(test_name),
97 uuid::Uuid::new_v4().simple()
98 );
99
100 let pool = sqlx::postgres::PgPoolOptions::new()
102 .max_connections(1)
103 .connect(&base_url)
104 .await
105 .map_err(ForgeError::Sql)?;
106
107 sqlx::query(&format!("CREATE DATABASE \"{}\"", db_name))
109 .execute(&pool)
110 .await
111 .map_err(ForgeError::Sql)?;
112
113 let test_url = replace_db_name(&base_url, &db_name);
115
116 let test_pool = sqlx::postgres::PgPoolOptions::new()
117 .max_connections(5)
118 .connect(&test_url)
119 .await
120 .map_err(ForgeError::Sql)?;
121
122 Ok(IsolatedTestDb {
123 pool: test_pool,
124 db_name,
125 base_url,
126 })
127 }
128}
129
130pub struct IsolatedTestDb {
136 pool: PgPool,
137 db_name: String,
138 base_url: String,
139}
140
141impl IsolatedTestDb {
142 pub fn pool(&self) -> &PgPool {
144 &self.pool
145 }
146
147 pub fn db_name(&self) -> &str {
149 &self.db_name
150 }
151
152 pub async fn execute(&self, sql: &str) -> Result<()> {
154 sqlx::query(sql)
155 .execute(&self.pool)
156 .await
157 .map_err(ForgeError::Sql)?;
158 Ok(())
159 }
160
161 pub async fn run_sql(&self, sql: &str) -> Result<()> {
166 let statements = split_sql_statements(sql);
167 for statement in statements {
168 let statement = statement.trim();
169 if statement.is_empty()
170 || statement.lines().all(|l| {
171 let l = l.trim();
172 l.is_empty() || l.starts_with("--")
173 })
174 {
175 continue;
176 }
177
178 sqlx::query(statement)
179 .execute(&self.pool)
180 .await
181 .map_err(|e| ForgeError::Database(format!("Failed to execute SQL: {}", e)))?;
182 }
183 Ok(())
184 }
185
186 pub async fn cleanup(self) -> Result<()> {
191 self.pool.close().await;
193
194 let pool = sqlx::postgres::PgPoolOptions::new()
196 .max_connections(1)
197 .connect(&self.base_url)
198 .await
199 .map_err(ForgeError::Sql)?;
200
201 let _ = sqlx::query(&format!(
203 "SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = '{}'",
204 self.db_name
205 ))
206 .execute(&pool)
207 .await;
208
209 sqlx::query(&format!("DROP DATABASE IF EXISTS \"{}\"", self.db_name))
210 .execute(&pool)
211 .await
212 .map_err(ForgeError::Sql)?;
213
214 Ok(())
215 }
216
217 pub async fn migrate(&self, migrations_dir: &Path) -> Result<()> {
230 if !migrations_dir.exists() {
231 debug!("Migrations directory does not exist: {:?}", migrations_dir);
232 return Ok(());
233 }
234
235 let mut migrations = Vec::new();
236
237 let entries = std::fs::read_dir(migrations_dir).map_err(ForgeError::Io)?;
238
239 for entry in entries {
240 let entry = entry.map_err(ForgeError::Io)?;
241 let path = entry.path();
242
243 if path.extension().map(|e| e == "sql").unwrap_or(false) {
244 let name = path
245 .file_stem()
246 .and_then(|s| s.to_str())
247 .ok_or_else(|| ForgeError::Config("Invalid migration filename".into()))?
248 .to_string();
249
250 let content = std::fs::read_to_string(&path).map_err(ForgeError::Io)?;
251 migrations.push((name, content));
252 }
253 }
254
255 migrations.sort_by(|a, b| a.0.cmp(&b.0));
257
258 debug!("Running {} migrations for test", migrations.len());
259
260 for (name, content) in migrations {
261 info!("Applying test migration: {}", name);
262
263 let up_sql = parse_up_sql(&content);
265
266 let statements = split_sql_statements(&up_sql);
268 for statement in statements {
269 let statement = statement.trim();
270 if statement.is_empty()
271 || statement.lines().all(|l| {
272 let l = l.trim();
273 l.is_empty() || l.starts_with("--")
274 })
275 {
276 continue;
277 }
278
279 sqlx::query(statement)
280 .execute(&self.pool)
281 .await
282 .map_err(|e| {
283 ForgeError::Database(format!("Failed to apply migration '{}': {}", name, e))
284 })?;
285 }
286 }
287
288 Ok(())
289 }
290}
291
/// Turns an arbitrary test name into a fragment that is safe inside a PostgreSQL
/// database name.
///
/// Every character that is not an ASCII letter or digit becomes `_`, and the result
/// is capped at 19 characters (which are also 19 bytes, since the output is ASCII).
///
/// Why 19: `TestDatabase::isolated` builds `forge_test_<fragment>_<32-char uuid>`, which
/// is 44 bytes of fixed overhead. PostgreSQL silently truncates identifiers to 63 bytes,
/// and truncation would make the stored name differ from the real one and could cut
/// into the UUID. 63 - 44 = 19.
fn sanitize_db_name(name: &str) -> String {
    const MAX_FRAGMENT_LEN: usize = 63 - 44;

    name.chars()
        .map(|c| if c.is_ascii_alphanumeric() { c } else { '_' })
        .take(MAX_FRAGMENT_LEN)
        .collect()
}
299
/// Returns `url` with its database name (the URL path) replaced by `new_db`.
///
/// Any query string is preserved. If the URL has no path, `/new_db` is inserted after
/// the authority (host/port), e.g. `postgres://localhost:5432` becomes
/// `postgres://localhost:5432/new_db`.
fn replace_db_name(url: &str, new_db: &str) -> String {
    // Split off the query string first: its values (e.g. `sslrootcert=/etc/ca.pem`)
    // may legitimately contain '/'.
    let (without_query, query) = match url.find('?') {
        Some(idx) => url.split_at(idx),
        None => (url, ""),
    };

    // Skip past "scheme://" so its slashes are not mistaken for the path separator.
    // The path (database name) begins at the first '/' after the authority.
    let authority_start = without_query.find("://").map_or(0, |idx| idx + 3);
    let base = match without_query[authority_start..].find('/') {
        Some(idx) => &without_query[..authority_start + idx],
        None => without_query,
    };

    format!("{}/{}{}", base, new_db, query)
}
316
/// Extracts the "up" half of a migration file.
///
/// - Everything from the earliest `-- @down` / `--@down` marker onward is discarded.
///   Matching ignores ASCII case.
/// - On each remaining line, an `-- @up` / `--@up` marker is removed together with the
///   rest of that line. The marker opens a `--` comment, so any trailing text such as
///   `-- @up create users` is comment text and must not leak out as SQL.
/// - Content without markers is returned unchanged apart from trimming.
fn parse_up_sql(content: &str) -> String {
    /// Byte offset of the earliest occurrence of any of `markers` in `text`, ignoring
    /// ASCII case.
    fn find_marker(text: &str, markers: &[&str]) -> Option<usize> {
        // ASCII lowercasing never changes byte lengths, so offsets map back to `text`.
        let lowered = text.to_ascii_lowercase();
        markers.iter().filter_map(|m| lowered.find(m)).min()
    }

    const UP_MARKERS: [&str; 2] = ["-- @up", "--@up"];
    const DOWN_MARKERS: [&str; 2] = ["-- @down", "--@down"];

    let up_part = match find_marker(content, &DOWN_MARKERS) {
        Some(idx) => &content[..idx],
        None => content,
    };

    up_part
        .split('\n')
        .map(|line| match find_marker(line, &UP_MARKERS) {
            Some(idx) => &line[..idx],
            None => line,
        })
        .collect::<Vec<_>>()
        .join("\n")
        .trim()
        .to_string()
}
343
/// Splits a PostgreSQL script into individual statements on top-level `;`.
///
/// A `;` ends a statement only when it is outside all of the following:
/// - single-quoted string literals (`'a;b'`), including `''` escapes and backslash
///   escapes inside `E'...'` strings;
/// - double-quoted identifiers (`"a;b"`);
/// - `--` line comments and `/* ... */` block comments, which may be nested;
/// - dollar-quoted strings (`$$ ... $$`, `$tag$ ... $tag$`), e.g. function bodies.
///
/// Returned statements are trimmed, carry no trailing `;`, and are never empty.
/// Comments stay attached to the statement that follows them. An unterminated quote
/// or comment extends to the end of the input.
fn split_sql_statements(sql: &str) -> Vec<String> {
    /// Appends `text` as a statement unless only whitespace / a stray `;` remains.
    fn push_statement(statements: &mut Vec<String>, text: &str) {
        let stmt = text.trim().trim_end_matches(';').trim();
        if !stmt.is_empty() {
            statements.push(stmt.to_string());
        }
    }

    /// Bytes that may continue an unquoted PostgreSQL identifier (`$` included).
    fn is_ident_byte(b: u8) -> bool {
        b.is_ascii_alphanumeric() || b == b'_' || b == b'$' || b >= 0x80
    }

    /// True if the quote at index `quote` opens an `E'...'` (escape) string.
    fn is_escape_string(bytes: &[u8], quote: usize) -> bool {
        quote >= 1
            && matches!(bytes[quote - 1], b'E' | b'e')
            && (quote < 2 || !is_ident_byte(bytes[quote - 2]))
    }

    /// Index just past the closing `'` of a literal whose body starts at `j`.
    fn skip_single_quoted(bytes: &[u8], mut j: usize, backslash_escapes: bool) -> usize {
        while j < bytes.len() {
            match bytes[j] {
                b'\\' if backslash_escapes => j += 2,
                b'\'' if bytes.get(j + 1) == Some(&b'\'') => j += 2,
                b'\'' => return j + 1,
                _ => j += 1,
            }
        }
        bytes.len()
    }

    /// Index just past the `*/` closing a (possibly nested) comment whose body starts at `j`.
    fn skip_block_comment(bytes: &[u8], mut j: usize) -> usize {
        let mut depth = 1usize;
        while j < bytes.len() {
            if bytes[j] == b'/' && bytes.get(j + 1) == Some(&b'*') {
                depth += 1;
                j += 2;
            } else if bytes[j] == b'*' && bytes.get(j + 1) == Some(&b'/') {
                depth -= 1;
                j += 2;
                if depth == 0 {
                    return j;
                }
            } else {
                j += 1;
            }
        }
        bytes.len()
    }

    /// If a dollar-quote delimiter (`$$` or `$tag$`) starts at `i`, the index just past it.
    /// Tags follow PostgreSQL's rules: they start with a letter, `_` or non-ASCII byte
    /// and continue with those or digits. So `$1` is a parameter, not a tag.
    fn dollar_tag_end(bytes: &[u8], i: usize) -> Option<usize> {
        let mut j = i + 1;
        if matches!(bytes.get(j), Some(&b) if b.is_ascii_alphabetic() || b == b'_' || b >= 0x80)
        {
            j += 1;
            while matches!(
                bytes.get(j),
                Some(&b) if b.is_ascii_alphanumeric() || b == b'_' || b >= 0x80
            ) {
                j += 1;
            }
        }
        (bytes.get(j) == Some(&b'$')).then(|| j + 1)
    }

    // Scanning bytes is sound for UTF-8: every delimiter we look for is ASCII, and
    // ASCII bytes never occur inside multi-byte sequences, so every index we slice
    // `sql` at is a char boundary.
    let bytes = sql.as_bytes();
    let len = bytes.len();
    let mut statements = Vec::new();
    let mut start = 0;
    let mut i = 0;

    while i < len {
        i = match bytes[i] {
            b';' => {
                push_statement(&mut statements, &sql[start..i]);
                start = i + 1;
                i + 1
            }
            b'\'' => skip_single_quoted(bytes, i + 1, is_escape_string(bytes, i)),
            b'"' => sql[i + 1..].find('"').map_or(len, |p| i + 1 + p + 1),
            b'-' if bytes.get(i + 1) == Some(&b'-') => {
                // Stop at the newline itself; it is ordinary whitespace afterwards.
                sql[i..].find('\n').map_or(len, |p| i + p)
            }
            b'/' if bytes.get(i + 1) == Some(&b'*') => skip_block_comment(bytes, i + 2),
            // A '$' that continues an identifier (e.g. `foo$bar`) never opens a quote.
            b'$' if i == 0 || !is_ident_byte(bytes[i - 1]) => match dollar_tag_end(bytes, i) {
                Some(tag_end) => {
                    let tag = &sql[i..tag_end];
                    sql[tag_end..]
                        .find(tag)
                        .map_or(len, |p| tag_end + p + tag.len())
                }
                None => i + 1,
            },
            _ => i + 1,
        };
    }

    push_statement(&mut statements, &sql[start..]);
    statements
}
407
#[cfg(test)]
mod tests {
    use super::*;

    /// Separators and punctuation collapse to underscores; plain names pass through.
    #[test]
    fn test_sanitize_db_name() {
        let cases = [
            ("my_test", "my_test"),
            ("my-test", "my_test"),
            ("my test", "my_test"),
            ("test::function", "test__function"),
        ];
        for (input, expected) in cases {
            assert_eq!(sanitize_db_name(input), expected);
        }
    }

    /// Only the database segment changes; credentials, port and query survive.
    #[test]
    fn test_replace_db_name() {
        let cases = [
            (
                "postgres://localhost/olddb",
                "postgres://localhost/newdb",
            ),
            (
                "postgres://user:pass@localhost:5432/olddb",
                "postgres://user:pass@localhost:5432/newdb",
            ),
            (
                "postgres://localhost/olddb?sslmode=disable",
                "postgres://localhost/newdb?sslmode=disable",
            ),
        ];
        for (url, expected) in cases {
            assert_eq!(replace_db_name(url, "newdb"), expected);
        }
    }
}