1use sqlx::PgPool;
13
14use crate::error::{ForgeError, Result};
15
16#[cfg(feature = "embedded-db")]
17use tokio::sync::OnceCell;
18
/// Process-wide cell holding the embedded PostgreSQL instance.
///
/// It is filled by the first successful call to `TestDatabase::embedded`,
/// which sets up and starts the server; later calls reuse the stored instance.
#[cfg(feature = "embedded-db")]
static EMBEDDED_PG: OnceCell<postgresql_embedded::PostgreSQL> = OnceCell::const_new();
21
22pub struct TestDatabase {
41 pool: PgPool,
42 url: String,
43}
44
45impl TestDatabase {
46 pub async fn from_url(url: &str) -> Result<Self> {
50 let pool = sqlx::postgres::PgPoolOptions::new()
51 .max_connections(10)
52 .connect(url)
53 .await
54 .map_err(ForgeError::Sql)?;
55
56 Ok(Self {
57 pool,
58 url: url.to_string(),
59 })
60 }
61
62 pub async fn from_env() -> Result<Self> {
67 let url = std::env::var("TEST_DATABASE_URL").map_err(|_| {
68 ForgeError::Database(
69 "TEST_DATABASE_URL not set. Set it explicitly for database tests.".to_string(),
70 )
71 })?;
72 Self::from_url(&url).await
73 }
74
75 #[cfg(feature = "embedded-db")]
80 pub async fn embedded() -> Result<Self> {
81 let pg = EMBEDDED_PG
82 .get_or_try_init(|| async {
83 let mut pg = postgresql_embedded::PostgreSQL::default();
84 pg.setup().await.map_err(|e| {
85 ForgeError::Database(format!("Failed to setup embedded Postgres: {}", e))
86 })?;
87 pg.start().await.map_err(|e| {
88 ForgeError::Database(format!("Failed to start embedded Postgres: {}", e))
89 })?;
90 Ok::<_, ForgeError>(pg)
91 })
92 .await?;
93
94 let url = pg.settings().url("postgres");
95 Self::from_url(&url).await
96 }
97
98 pub fn pool(&self) -> &PgPool {
100 &self.pool
101 }
102
103 pub fn url(&self) -> &str {
105 &self.url
106 }
107
108 pub async fn execute(&self, sql: &str) -> Result<()> {
110 sqlx::query(sql)
111 .execute(&self.pool)
112 .await
113 .map_err(ForgeError::Sql)?;
114 Ok(())
115 }
116
117 pub async fn isolated(&self, test_name: &str) -> Result<IsolatedTestDb> {
122 let base_url = self.url.clone();
123 let db_name = format!(
125 "forge_test_{}_{}",
126 sanitize_db_name(test_name),
127 uuid::Uuid::new_v4().simple()
128 );
129
130 let pool = sqlx::postgres::PgPoolOptions::new()
132 .max_connections(1)
133 .connect(&base_url)
134 .await
135 .map_err(ForgeError::Sql)?;
136
137 sqlx::query(&format!("CREATE DATABASE \"{}\"", db_name))
139 .execute(&pool)
140 .await
141 .map_err(ForgeError::Sql)?;
142
143 let test_url = replace_db_name(&base_url, &db_name);
145
146 let test_pool = sqlx::postgres::PgPoolOptions::new()
147 .max_connections(5)
148 .connect(&test_url)
149 .await
150 .map_err(ForgeError::Sql)?;
151
152 Ok(IsolatedTestDb {
153 pool: test_pool,
154 db_name,
155 base_url,
156 })
157 }
158}
159
160pub struct IsolatedTestDb {
166 pool: PgPool,
167 db_name: String,
168 base_url: String,
169}
170
171impl IsolatedTestDb {
172 pub fn pool(&self) -> &PgPool {
174 &self.pool
175 }
176
177 pub fn db_name(&self) -> &str {
179 &self.db_name
180 }
181
182 pub async fn execute(&self, sql: &str) -> Result<()> {
184 sqlx::query(sql)
185 .execute(&self.pool)
186 .await
187 .map_err(ForgeError::Sql)?;
188 Ok(())
189 }
190
191 pub async fn cleanup(self) -> Result<()> {
196 self.pool.close().await;
198
199 let pool = sqlx::postgres::PgPoolOptions::new()
201 .max_connections(1)
202 .connect(&self.base_url)
203 .await
204 .map_err(ForgeError::Sql)?;
205
206 let _ = sqlx::query(&format!(
208 "SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = '{}'",
209 self.db_name
210 ))
211 .execute(&pool)
212 .await;
213
214 sqlx::query(&format!("DROP DATABASE IF EXISTS \"{}\"", self.db_name))
215 .execute(&pool)
216 .await
217 .map_err(ForgeError::Sql)?;
218
219 Ok(())
220 }
221}
222
/// Turns an arbitrary test name into a fragment that is safe to embed in a
/// PostgreSQL database name.
///
/// Every character that is not an ASCII letter or digit is replaced by `_`,
/// and the result is cut to at most 19 characters (and therefore 19 bytes).
fn sanitize_db_name(name: &str) -> String {
    // The fragment is embedded as `forge_test_<fragment>_<32 hex digits>`:
    // 11 + 1 + 32 = 44 bytes of fixed overhead. PostgreSQL keeps at most 63
    // bytes of an identifier, which leaves 19 bytes for the fragment. Staying
    // within that budget keeps the name stored by the caller identical to the
    // name the server actually uses.
    const MAX_FRAGMENT_LEN: usize = 19;

    name.chars()
        // ASCII only: every kept character is exactly one byte, so the
        // character cap below is also a byte cap.
        .map(|c| if c.is_ascii_alphanumeric() { c } else { '_' })
        .take(MAX_FRAGMENT_LEN)
        .collect()
}
230
/// Returns `url` with its database name (the last path segment) replaced by
/// `new_db`, keeping any query string intact.
///
/// If the URL has no path at all (for example `postgres://host:5432`),
/// `/new_db` is appended after the authority.
fn replace_db_name(url: &str, new_db: &str) -> String {
    // Split off the query string first so that '/' characters inside
    // parameter values (e.g. `sslrootcert=/etc/ca.pem`) are never mistaken
    // for the path separator. `query` keeps its leading '?'.
    let (location, query) = match url.find('?') {
        Some(idx) => url.split_at(idx),
        None => (url, ""),
    };

    // Skip the "scheme://" prefix: its slashes are not part of the path.
    let authority_start = location.find("://").map_or(0, |idx| idx + 3);

    // Everything before the last '/' of the path is kept; without a path the
    // whole location is kept and the database name is appended to it.
    let prefix = match location[authority_start..].rfind('/') {
        Some(rel) => &location[..authority_start + rel],
        None => location,
    };

    format!("{}/{}{}", prefix, new_db, query)
}
247
#[cfg(test)]
mod tests {
    use super::*;

    /// Characters other than letters and digits are each turned into `_`.
    #[test]
    fn test_sanitize_db_name() {
        let cases = [
            ("my_test", "my_test"),
            ("my-test", "my_test"),
            ("my test", "my_test"),
            ("test::function", "test__function"),
        ];

        for &(input, expected) in &cases {
            assert_eq!(sanitize_db_name(input), expected);
        }
    }

    /// The database segment is swapped while credentials, port and query
    /// string are left untouched.
    #[test]
    fn test_replace_db_name() {
        let cases = [
            ("postgres://localhost/olddb", "postgres://localhost/newdb"),
            (
                "postgres://user:pass@localhost:5432/olddb",
                "postgres://user:pass@localhost:5432/newdb",
            ),
            (
                "postgres://localhost/olddb?sslmode=disable",
                "postgres://localhost/newdb?sslmode=disable",
            ),
        ];

        for &(url, expected) in &cases {
            assert_eq!(replace_db_name(url, "newdb"), expected);
        }
    }
}