1use std::sync::Arc;
10use sqlx::{PgPool, Postgres, Transaction, Row};
11use serde_json::Value as JsonValue;
12use crate::{TestError, TestResult};
13
/// Handle to the PostgreSQL database used by tests.
///
/// Holds a connection pool and, when built through `with_transaction`, a
/// transaction that was begun on that pool. Cloning the handle clones the
/// `Arc`, so clones refer to the same transaction object.
#[derive(Clone)]
pub struct TestDatabase {
    /// Connection pool that the query helpers on this type execute against.
    pool: PgPool,
    /// Transaction begun by `with_transaction`; `None` when built with `new`.
    /// NOTE(review): no helper visible in this file reads this field, so
    /// queries are not run inside it — confirm the intended isolation.
    transaction: Option<Arc<std::sync::Mutex<Transaction<'static, Postgres>>>>,
}
20
21impl TestDatabase {
22 pub async fn new() -> TestResult<Self> {
27 let database_url = std::env::var("TEST_DATABASE_URL")
28 .or_else(|_| std::env::var("DATABASE_URL"))
29 .unwrap_or_else(|_| "postgresql://postgres:postgres@localhost:5432/elif_test".to_string());
30
31 let pool = PgPool::connect(&database_url).await?;
32
33 Self::ensure_test_database(&pool).await?;
35
36 Ok(Self {
37 pool,
38 transaction: None,
39 })
40 }
41
42 pub async fn with_transaction() -> TestResult<Self> {
46 let database_url = std::env::var("TEST_DATABASE_URL")
47 .or_else(|_| std::env::var("DATABASE_URL"))
48 .unwrap_or_else(|_| "postgresql://postgres:postgres@localhost:5432/elif_test".to_string());
49
50 let pool = PgPool::connect(&database_url).await?;
51 Self::ensure_test_database(&pool).await?;
52
53 let transaction = pool.begin().await?;
54
55 Ok(Self {
56 pool,
57 transaction: Some(Arc::new(std::sync::Mutex::new(transaction))),
58 })
59 }
60
61 pub fn pool(&self) -> &PgPool {
63 &self.pool
64 }
65
66 pub async fn execute(&self, sql: &str) -> TestResult<()> {
68 sqlx::query(sql).execute(&self.pool).await?;
69 Ok(())
70 }
71
72 pub async fn fetch_one(&self, sql: &str) -> TestResult<sqlx::postgres::PgRow> {
74 let row = sqlx::query(sql).fetch_one(&self.pool).await?;
75 Ok(row)
76 }
77
78 pub async fn fetch_all(&self, sql: &str) -> TestResult<Vec<sqlx::postgres::PgRow>> {
80 let rows = sqlx::query(sql).fetch_all(&self.pool).await?;
81 Ok(rows)
82 }
83
84 pub async fn record_exists(&self, table: &str, conditions: &[(&str, &dyn ToString)]) -> TestResult<bool> {
86 let mut query = format!("SELECT 1 FROM {} WHERE", table);
87 let mut params = Vec::new();
88
89 for (i, (column, value)) in conditions.iter().enumerate() {
90 if i > 0 {
91 query.push_str(" AND");
92 }
93 query.push_str(&format!(" {} = ${}", column, i + 1));
94 params.push(value.to_string());
95 }
96
97 let mut sql_query = sqlx::query(&query);
98 for param in params {
99 sql_query = sql_query.bind(param);
100 }
101
102 let result = sql_query.fetch_optional(&self.pool).await?;
103 Ok(result.is_some())
104 }
105
106 pub async fn count_records(&self, table: &str, conditions: &[(&str, &dyn ToString)]) -> TestResult<i64> {
108 let mut query = format!("SELECT COUNT(*) FROM {}", table);
109 let mut params = Vec::new();
110
111 if !conditions.is_empty() {
112 query.push_str(" WHERE");
113 for (i, (column, value)) in conditions.iter().enumerate() {
114 if i > 0 {
115 query.push_str(" AND");
116 }
117 query.push_str(&format!(" {} = ${}", column, i + 1));
118 params.push(value.to_string());
119 }
120 }
121
122 let mut sql_query = sqlx::query_scalar(&query);
123 for param in params {
124 sql_query = sql_query.bind(param);
125 }
126
127 let count: i64 = sql_query.fetch_one(&self.pool).await?;
128 Ok(count)
129 }
130
131 pub async fn cleanup(&self) -> TestResult<()> {
133 let tables_query = r#"
134 SELECT tablename FROM pg_tables
135 WHERE schemaname = 'public'
136 AND tablename != '_sqlx_migrations'
137 "#;
138
139 let rows = sqlx::query(tables_query).fetch_all(&self.pool).await?;
140
141 for row in rows {
142 let table_name: String = row.get("tablename");
143 let truncate_sql = format!("TRUNCATE TABLE {} RESTART IDENTITY CASCADE", table_name);
144 sqlx::query(&truncate_sql).execute(&self.pool).await?;
145 }
146
147 Ok(())
148 }
149
150 pub async fn seed_from_json(&self, data: JsonValue) -> TestResult<()> {
152 if let Some(tables) = data.as_object() {
153 for (table_name, records) in tables {
154 if let Some(records_array) = records.as_array() {
155 for record in records_array {
156 self.insert_record(table_name, record).await?;
157 }
158 }
159 }
160 }
161 Ok(())
162 }
163
164 async fn insert_record(&self, table: &str, record: &JsonValue) -> TestResult<()> {
166 if let Some(fields) = record.as_object() {
167 let columns: Vec<String> = fields.keys().cloned().collect();
168 let placeholders: Vec<String> = (1..=columns.len()).map(|i| format!("${}", i)).collect();
169
170 let sql = format!(
171 "INSERT INTO {} ({}) VALUES ({})",
172 table,
173 columns.join(", "),
174 placeholders.join(", ")
175 );
176
177 let mut query = sqlx::query(&sql);
178 for column in &columns {
179 if let Some(value) = fields.get(column) {
180 match value {
181 JsonValue::String(s) => query = query.bind(s),
182 JsonValue::Number(n) => {
183 if let Some(i) = n.as_i64() {
184 query = query.bind(i);
185 } else if let Some(f) = n.as_f64() {
186 query = query.bind(f);
187 }
188 },
189 JsonValue::Bool(b) => query = query.bind(b),
190 JsonValue::Null => query = query.bind(Option::<String>::None),
191 _ => query = query.bind(value.to_string()),
192 }
193 }
194 }
195
196 query.execute(&self.pool).await?;
197 }
198 Ok(())
199 }
200
201 async fn ensure_test_database(pool: &PgPool) -> TestResult<()> {
203 sqlx::query("SELECT 1").fetch_one(pool).await?;
206 Ok(())
207 }
208}
209
/// Owns a single database transaction that statements can be executed on
/// and that can be rolled back explicitly.
pub struct DatabaseTransaction {
    /// The transaction every statement issued through this wrapper runs on.
    transaction: Transaction<'static, Postgres>,
}
214
215impl DatabaseTransaction {
216 pub async fn new(pool: &PgPool) -> TestResult<Self> {
218 let transaction = pool.begin().await?;
219 Ok(Self { transaction })
220 }
221
222 pub async fn execute(&mut self, sql: &str) -> TestResult<()> {
224 sqlx::query(sql).execute(&mut *self.transaction).await?;
225 Ok(())
226 }
227
228 pub async fn rollback(self) -> TestResult<()> {
230 self.transaction.rollback().await?;
231 Ok(())
232 }
233}
234
/// Assertion helpers that check database contents through a borrowed
/// `TestDatabase` and report mismatches as `TestError::Assertion`.
pub struct DatabaseAssertions<'a> {
    /// Database handle the assertions query.
    db: &'a TestDatabase,
}
239
240impl<'a> DatabaseAssertions<'a> {
241 pub fn new(db: &'a TestDatabase) -> Self {
242 Self { db }
243 }
244
245 pub async fn assert_record_exists(&self, table: &str, conditions: &[(&str, &dyn ToString)]) -> TestResult<()> {
247 let exists = self.db.record_exists(table, conditions).await?;
248 if !exists {
249 let conditions_str = conditions.iter()
250 .map(|(k, v)| format!("{}={}", k, v.to_string()))
251 .collect::<Vec<_>>()
252 .join(", ");
253 return Err(TestError::Assertion {
254 message: format!("Expected record to exist in table '{}' with conditions: {}", table, conditions_str),
255 });
256 }
257 Ok(())
258 }
259
260 pub async fn assert_record_not_exists(&self, table: &str, conditions: &[(&str, &dyn ToString)]) -> TestResult<()> {
262 let exists = self.db.record_exists(table, conditions).await?;
263 if exists {
264 let conditions_str = conditions.iter()
265 .map(|(k, v)| format!("{}={}", k, v.to_string()))
266 .collect::<Vec<_>>()
267 .join(", ");
268 return Err(TestError::Assertion {
269 message: format!("Expected record to NOT exist in table '{}' with conditions: {}", table, conditions_str),
270 });
271 }
272 Ok(())
273 }
274
275 pub async fn assert_record_count(&self, table: &str, expected_count: i64, conditions: &[(&str, &dyn ToString)]) -> TestResult<()> {
277 let actual_count = self.db.count_records(table, conditions).await?;
278 if actual_count != expected_count {
279 return Err(TestError::Assertion {
280 message: format!("Expected {} records in table '{}', found {}", expected_count, table, actual_count),
281 });
282 }
283 Ok(())
284 }
285}
286
/// Checks that `$table` contains a row matching every `field => value` pair.
///
/// Expands to an awaited call to `DatabaseAssertions::assert_record_exists`,
/// so it must be used in an async context, and evaluates to that call's
/// result rather than panicking. `DatabaseAssertions` is referenced by its
/// bare name and must be in scope at the call site. Each value must be a
/// reference that coerces to `&dyn ToString`.
///
/// A trailing comma after the last pair is accepted.
#[macro_export]
macro_rules! assert_database_has {
    ($db:expr, $table:expr, $($field:expr => $value:expr),+ $(,)?) => {
        {
            // A borrowed array needs no heap allocation, and temporaries in
            // the value expressions are kept alive for the whole block.
            let conditions: &[(&str, &dyn ToString)] = &[
                $(($field, $value),)+
            ];
            DatabaseAssertions::new($db).assert_record_exists($table, conditions).await
        }
    };
}
299
/// Checks that `$table` holds exactly `$count` rows, optionally restricted
/// to rows matching every `field => value` pair.
///
/// Expands to an awaited call to `DatabaseAssertions::assert_record_count`,
/// so it must be used in an async context, and evaluates to that call's
/// result rather than panicking. `DatabaseAssertions` is referenced by its
/// bare name and must be in scope at the call site. Each value must be a
/// reference that coerces to `&dyn ToString`.
///
/// A trailing comma is accepted in both forms.
#[macro_export]
macro_rules! assert_database_count {
    ($db:expr, $table:expr, $count:expr $(,)?) => {
        DatabaseAssertions::new($db).assert_record_count($table, $count, &[]).await
    };
    ($db:expr, $table:expr, $count:expr, $($field:expr => $value:expr),+ $(,)?) => {
        {
            // A borrowed array needs no heap allocation, and temporaries in
            // the value expressions are kept alive for the whole block.
            let conditions: &[(&str, &dyn ToString)] = &[
                $(($field, $value),)+
            ];
            DatabaseAssertions::new($db).assert_record_count($table, $count, conditions).await
        }
    };
}
314
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Checks the shape of a seed fixture; no database connection is made.
    #[tokio::test]
    async fn test_database_utils() -> TestResult<()> {
        let fixture = json!({
            "users": [
                {"id": 1, "name": "Test User", "email": "test@example.com"},
                {"id": 2, "name": "Another User", "email": "another@example.com"}
            ]
        });

        assert!(fixture.is_object());

        let users = fixture.get("users").and_then(|entry| entry.as_array());
        if let Some(users) = users {
            assert_eq!(users.len(), 2);
        }

        Ok(())
    }

    /// A sqlx error can be wrapped in the `Database` variant.
    #[test]
    fn test_database_error_handling() {
        let wrapped = TestError::Database(sqlx::Error::RowNotFound);
        let is_database_error = matches!(wrapped, TestError::Database(_));
        assert!(is_database_error);
    }

    /// The assertion message is part of the error's display text.
    #[test]
    fn test_assertion_error_creation() {
        let message = "Test assertion failed".to_string();
        let failure = TestError::Assertion { message };
        let rendered = failure.to_string();
        assert!(rendered.contains("Test assertion failed"));
    }
}