1#![allow(clippy::unwrap_used, clippy::indexing_slicing)]
6
7use sqlx::PgPool;
8use std::path::Path;
9#[cfg(feature = "testcontainers")]
10use std::sync::Arc;
11use tracing::{debug, info};
12
13use crate::error::{ForgeError, Result};
14
/// Shared handle to the ephemeral Postgres testcontainer backing a database.
/// `None` when the database was reached via an external URL. Holding the Arc
/// keeps the container (and its mapped port) alive for as long as any pool
/// derived from it is in use.
#[cfg(feature = "testcontainers")]
type PgContainer =
    Arc<Option<testcontainers::ContainerAsync<testcontainers_modules::postgres::Postgres>>>;
18
/// Handle to the base Postgres instance used to provision per-test databases.
pub struct TestDatabase {
    // Pool connected to the base database (up to 10 connections).
    pool: PgPool,
    // URL the pool was built from; reused to derive per-test database URLs.
    url: String,
    // Keeps the backing container alive while this handle exists.
    #[cfg(feature = "testcontainers")]
    _container: PgContainer,
}
33
34impl TestDatabase {
35 pub async fn from_url(url: &str) -> Result<Self> {
37 let pool = sqlx::postgres::PgPoolOptions::new()
38 .max_connections(10)
39 .connect(url)
40 .await
41 .map_err(ForgeError::Sql)?;
42
43 Ok(Self {
44 pool,
45 url: url.to_string(),
46 #[cfg(feature = "testcontainers")]
47 _container: Arc::new(None),
48 })
49 }
50
51 pub async fn from_env() -> Result<Self> {
54 match std::env::var("TEST_DATABASE_URL") {
55 Ok(url) => Self::from_url(&url).await,
56 Err(_) => {
57 #[cfg(feature = "testcontainers")]
58 {
59 Self::from_container().await
60 }
61 #[cfg(not(feature = "testcontainers"))]
62 {
63 Err(ForgeError::Database(
64 "TEST_DATABASE_URL not set. Set it explicitly for database tests, \
65 or enable the `testcontainers` feature for automatic provisioning."
66 .to_string(),
67 ))
68 }
69 }
70 }
71 }
72
73 #[cfg(feature = "testcontainers")]
74 async fn from_container() -> Result<Self> {
75 use testcontainers::ImageExt;
76 use testcontainers::runners::AsyncRunner;
77 use testcontainers_modules::postgres::Postgres;
78
79 let container = Postgres::default()
81 .with_tag("18-alpine")
82 .start()
83 .await
84 .map_err(|e| ForgeError::Database(format!("Failed to start PG container: {e}")))?;
85
86 let port = container
87 .get_host_port_ipv4(5432)
88 .await
89 .map_err(|e| ForgeError::Database(format!("Failed to get container port: {e}")))?;
90
91 let url = format!("postgres://postgres:postgres@localhost:{port}/postgres");
92 let pool = sqlx::postgres::PgPoolOptions::new()
93 .max_connections(10)
94 .acquire_timeout(std::time::Duration::from_secs(30))
95 .connect(&url)
96 .await
97 .map_err(ForgeError::Sql)?;
98
99 Ok(Self {
100 pool,
101 url,
102 _container: Arc::new(Some(container)),
103 })
104 }
105
106 pub fn pool(&self) -> &PgPool {
108 &self.pool
109 }
110
111 pub fn url(&self) -> &str {
113 &self.url
114 }
115
116 pub async fn execute(&self, sql: &str) -> Result<()> {
118 sqlx::query(sql)
119 .execute(&self.pool)
120 .await
121 .map_err(ForgeError::Sql)?;
122 Ok(())
123 }
124
125 pub async fn isolated(&self, test_name: &str) -> Result<IsolatedTestDb> {
130 let base_url = self.url.clone();
131 let db_name = format!(
133 "forge_test_{}_{}",
134 sanitize_db_name(test_name),
135 uuid::Uuid::new_v4().simple()
136 );
137
138 let pool = sqlx::postgres::PgPoolOptions::new()
140 .max_connections(1)
141 .connect(&base_url)
142 .await
143 .map_err(ForgeError::Sql)?;
144
145 sqlx::query(&format!("CREATE DATABASE \"{}\"", db_name))
147 .execute(&pool)
148 .await
149 .map_err(ForgeError::Sql)?;
150
151 let test_url = replace_db_name(&base_url, &db_name);
153
154 let test_pool = sqlx::postgres::PgPoolOptions::new()
155 .max_connections(5)
156 .connect(&test_url)
157 .await
158 .map_err(ForgeError::Sql)?;
159
160 Ok(IsolatedTestDb {
161 pool: test_pool,
162 db_name,
163 base_url,
164 #[cfg(feature = "testcontainers")]
165 _container: self._container.clone(),
166 })
167 }
168}
169
/// A freshly created, uniquely named database dedicated to a single test.
pub struct IsolatedTestDb {
    // Pool connected to the isolated database itself (up to 5 connections).
    pool: PgPool,
    // Generated database name (sanitized test name + UUID suffix).
    db_name: String,
    // URL of the base database; `cleanup` reconnects here to drop this one.
    base_url: String,
    // Keeps the backing container alive while this handle exists.
    #[cfg(feature = "testcontainers")]
    _container: PgContainer,
}
182
183impl IsolatedTestDb {
184 pub async fn setup(test_name: &str, internal_sql: &str, migrations_dir: &Path) -> Result<Self> {
198 let base = TestDatabase::from_env().await?;
199 let db = base.isolated(test_name).await?;
200 db.run_sql(internal_sql).await?;
201 db.migrate(migrations_dir).await?;
202 Ok(db)
203 }
204
205 pub fn pool(&self) -> &PgPool {
207 &self.pool
208 }
209
210 pub fn db_name(&self) -> &str {
212 &self.db_name
213 }
214
215 pub async fn execute(&self, sql: &str) -> Result<()> {
217 sqlx::query(sql)
218 .execute(&self.pool)
219 .await
220 .map_err(ForgeError::Sql)?;
221 Ok(())
222 }
223
224 pub async fn run_sql(&self, sql: &str) -> Result<()> {
229 for stmt in split_sql_statements(sql) {
230 let stmt = stmt.trim();
231 if is_blank_sql(stmt) {
232 continue;
233 }
234 sqlx::query(stmt)
235 .execute(&self.pool)
236 .await
237 .map_err(|e| ForgeError::Database(format!("Failed to execute SQL: {e}")))?;
238 }
239 Ok(())
240 }
241
242 pub async fn cleanup(self) -> Result<()> {
247 self.pool.close().await;
249
250 let pool = sqlx::postgres::PgPoolOptions::new()
252 .max_connections(1)
253 .connect(&self.base_url)
254 .await
255 .map_err(ForgeError::Sql)?;
256
257 let _ = sqlx::query(&format!(
259 "SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = '{}'",
260 self.db_name
261 ))
262 .execute(&pool)
263 .await;
264
265 sqlx::query(&format!("DROP DATABASE IF EXISTS \"{}\"", self.db_name))
266 .execute(&pool)
267 .await
268 .map_err(ForgeError::Sql)?;
269
270 Ok(())
271 }
272
273 pub async fn migrate(&self, migrations_dir: &Path) -> Result<()> {
286 if !migrations_dir.exists() {
287 debug!("Migrations directory does not exist: {:?}", migrations_dir);
288 return Ok(());
289 }
290
291 let mut migrations = Vec::new();
292
293 let entries = std::fs::read_dir(migrations_dir).map_err(ForgeError::Io)?;
294
295 for entry in entries {
296 let entry = entry.map_err(ForgeError::Io)?;
297 let path = entry.path();
298
299 if path.extension().map(|e| e == "sql").unwrap_or(false) {
300 let name = path
301 .file_stem()
302 .and_then(|s| s.to_str())
303 .ok_or_else(|| ForgeError::Config("Invalid migration filename".into()))?
304 .to_string();
305
306 let content = std::fs::read_to_string(&path).map_err(ForgeError::Io)?;
307 migrations.push((name, content));
308 }
309 }
310
311 migrations.sort_by(|a, b| a.0.cmp(&b.0));
313
314 debug!("Running {} migrations for test", migrations.len());
315
316 for (name, content) in migrations {
317 info!("Applying test migration: {}", name);
318
319 let up_sql = parse_up_sql(&content);
321
322 for stmt in split_sql_statements(&up_sql) {
324 let stmt = stmt.trim();
325 if is_blank_sql(stmt) {
326 continue;
327 }
328 sqlx::query(stmt).execute(&self.pool).await.map_err(|e| {
329 ForgeError::Database(format!("Failed to apply migration '{name}': {e}"))
330 })?;
331 }
332 }
333
334 Ok(())
335 }
336}
337
/// Reports whether `sql` contains nothing executable: it is empty, or every
/// line is whitespace or a `--` line comment.
fn is_blank_sql(sql: &str) -> bool {
    if sql.is_empty() {
        return true;
    }
    sql.lines().all(|line| {
        let trimmed = line.trim();
        trimmed.is_empty() || trimmed.starts_with("--")
    })
}
344
/// Maps `name` into a safe database-name fragment: every non-alphanumeric
/// character becomes `_`, and the result is capped at 32 characters.
fn sanitize_db_name(name: &str) -> String {
    let mut sanitized = String::new();
    for ch in name.chars().take(32) {
        sanitized.push(if ch.is_alphanumeric() { ch } else { '_' });
    }
    sanitized
}
352
/// Returns `url` with its database path segment replaced by `new_db`,
/// preserving any query string.
///
/// Fixes two defects in the previous version: (1) for a URL with no database
/// path (e.g. `postgres://host`), `rfind('/')` matched the scheme's `//` and
/// the authority was clobbered (`postgres://newdb`); (2) a `/` inside the
/// query string could be mistaken for the path separator. We now split off
/// the query first and only look for `/` after the `://` scheme separator.
fn replace_db_name(url: &str, new_db: &str) -> String {
    // Separate the query string so its contents can't confuse path parsing.
    let (head, query) = match url.find('?') {
        Some(q) => url.split_at(q),
        None => (url, ""),
    };

    // Only consider '/' characters that come after "://" (or anywhere when
    // there is no scheme at all).
    let path_start = head.find("://").map_or(0, |i| i + 3);

    match head[path_start..].rfind('/') {
        // Replace the final path segment (the database name).
        Some(rel) => format!("{}{}{}", &head[..=path_start + rel], new_db, query),
        // No path at all: append the database name.
        None => format!("{}/{}{}", head, new_db, query),
    }
}
369
/// Extracts the "up" portion of a migration file: everything before the
/// earliest `-- @down` marker (any supported spelling), with `-- @up`
/// markers stripped out.
///
/// Fixes a defect: the previous version returned the cut point of the first
/// marker *variant in array order* (`find_map`), not the *earliest occurrence*
/// in the file — so if e.g. `--@DOWN` appeared before `-- @down`, down-SQL
/// leaked into the up part.
fn parse_up_sql(content: &str) -> String {
    let down_markers = ["-- @down", "--@down", "-- @DOWN", "--@DOWN"];
    let up_part = down_markers
        .iter()
        .filter_map(|m| content.find(m))
        .min()
        .map_or(content, |idx| &content[..idx]);

    strip_up_markers(up_part)
}

/// Removes every supported `-- @up` marker spelling and trims the result.
fn strip_up_markers(sql: &str) -> String {
    let mut cleaned = sql.to_string();
    for marker in ["-- @up", "--@up", "-- @UP", "--@UP"] {
        cleaned = cleaned.replace(marker, "");
    }
    cleaned.trim().to_string()
}
389
/// Splits `sql` on `;` into individual statements, ignoring semicolons that
/// appear inside dollar-quoted bodies (`$$ ... $$`, `$tag$ ... $tag$`) or —
/// new in this version — inside standard single-quoted string literals
/// (with `''` as the escape). Trailing semicolons are stripped and empty
/// statements dropped.
///
/// NOTE(review): `--` line comments containing `;` are still split — assumed
/// not to occur in our migration files; confirm if that changes.
fn split_sql_statements(sql: &str) -> Vec<String> {
    let mut statements = Vec::new();
    let mut current = String::new();
    let mut in_dollar_quote = false;
    let mut dollar_tag = String::new();
    let mut in_single_quote = false;
    let mut chars = sql.chars().peekable();

    while let Some(c) = chars.next() {
        current.push(c);

        if in_single_quote {
            if c == '\'' {
                if chars.peek() == Some(&'\'') {
                    // '' escapes a quote inside the literal; stay in-string.
                    current.push(chars.next().unwrap());
                } else {
                    in_single_quote = false;
                }
            }
            continue;
        }

        // Fix: track single-quoted literals so `VALUES ('a;b')` is not split
        // at the embedded semicolon. Quotes inside dollar-quoted bodies are
        // plain text and intentionally not tracked.
        if c == '\'' && !in_dollar_quote {
            in_single_quote = true;
            continue;
        }

        if c == '$' {
            // Collect a potential dollar-quote tag: `$$` or `$tag$`.
            let mut potential_tag = String::from("$");
            while let Some(&next_c) = chars.peek() {
                if next_c == '$' {
                    chars.next();
                    potential_tag.push('$');
                    current.push('$');
                    break;
                } else if next_c.is_alphanumeric() || next_c == '_' {
                    chars.next();
                    potential_tag.push(next_c);
                    current.push(next_c);
                } else {
                    // Not a tag (e.g. a `$1` placeholder); leave as-is.
                    break;
                }
            }

            if potential_tag.len() >= 2 && potential_tag.ends_with('$') {
                if in_dollar_quote && potential_tag == dollar_tag {
                    in_dollar_quote = false;
                    dollar_tag.clear();
                } else if !in_dollar_quote {
                    in_dollar_quote = true;
                    dollar_tag = potential_tag;
                }
            }
            continue;
        }

        if c == ';' && !in_dollar_quote {
            let stmt = current.trim().trim_end_matches(';').trim().to_string();
            if !stmt.is_empty() {
                statements.push(stmt);
            }
            current.clear();
        }
    }

    // Flush any trailing statement without a final semicolon.
    let stmt = current.trim().trim_end_matches(';').trim().to_string();
    if !stmt.is_empty() {
        statements.push(stmt);
    }

    statements
}
453
#[cfg(test)]
mod tests {
    use super::*;

    // Non-alphanumeric characters must all collapse to underscores.
    #[test]
    fn test_sanitize_db_name() {
        let cases = [
            ("my_test", "my_test"),
            ("my-test", "my_test"),
            ("my test", "my_test"),
            ("test::function", "test__function"),
        ];
        for (input, expected) in cases {
            assert_eq!(sanitize_db_name(input), expected);
        }
    }

    // The final path segment is swapped; credentials, ports, and query
    // strings are preserved untouched.
    #[test]
    fn test_replace_db_name() {
        let cases = [
            ("postgres://localhost/olddb", "postgres://localhost/newdb"),
            (
                "postgres://user:pass@localhost:5432/olddb",
                "postgres://user:pass@localhost:5432/newdb",
            ),
            (
                "postgres://localhost/olddb?sslmode=disable",
                "postgres://localhost/newdb?sslmode=disable",
            ),
        ];
        for (url, expected) in cases {
            assert_eq!(replace_db_name(url, "newdb"), expected);
        }
    }
}