1use std::{path::Path, thread};
2
3use sqlx::{migrate::Migrator, query, Connection, Executor, PgConnection, PgPool};
4use tokio::runtime::Runtime;
5use uuid::Uuid;
6
/// A throw-away PostgreSQL database for tests.
///
/// [`TestDb::new`] creates a uniquely named database and runs migrations on it; dropping the
/// value drops the database again. Keep the value alive for as long as the database is used.
#[must_use = "dropping a TestDb immediately drops the database it created"]
pub struct TestDb {
    /// Server host name or IP address.
    pub host: String,
    /// Server TCP port.
    pub port: u16,
    /// Role used to connect, create and drop the database.
    pub user: String,
    /// Password for `user`; when empty, no password segment is put in the connection URL.
    pub password: String,
    /// Name of the test database on the server.
    pub dbname: String,
}
14
15impl TestDb {
16 pub fn new(
17 host: impl Into<String>,
18 port: u16,
19 user: impl Into<String>,
20 password: impl Into<String>,
21 migration_path: impl Into<String>,
22 ) -> Self {
23 let host = host.into();
24 let user = user.into();
25 let password = password.into();
26
27 let uuid = Uuid::new_v4();
28 let dbname = format!("test_{}", uuid);
29 let dbname_cloned = dbname.clone();
30 let tdb = Self {
31 host,
32 port,
33 user,
34 password,
35 dbname,
36 };
37
38 let server_url = tdb.server_url();
39 let url = tdb.url();
40 let migration_path = migration_path.into();
41
42 thread::spawn(move || {
43 let rt = Runtime::new().unwrap();
44 rt.block_on(async move {
45 let mut conn = PgConnection::connect(&server_url).await.unwrap();
47 conn.execute(format!(r#"CREATE DATABASE "{}""#, dbname_cloned).as_str())
48 .await
49 .expect("Error while querying the reservation database");
50
51 let mut conn = PgConnection::connect(&url).await.unwrap();
53 let m = Migrator::new(Path::new(&migration_path)).await.unwrap();
54 m.run(&mut conn).await.unwrap();
55 });
56 })
57 .join()
58 .expect("Error thread create database ");
59
60 tdb
61 }
62
63 pub fn server_url(&self) -> String {
64 if !self.password.is_empty() {
65 format!(
66 "postgres://{}:{}@{}:{}",
67 self.user, self.password, self.host, self.port
68 )
69 } else {
70 format!("postgres://{}@{}:{}", self.user, self.host, self.port)
71 }
72 }
73
74 pub fn url(&self) -> String {
75 format!("{}/{}", self.server_url(), self.dbname)
76 }
77
78 pub async fn get_pool(&self) -> PgPool {
79 sqlx::postgres::PgPoolOptions::new()
80 .max_connections(5)
81 .connect(&self.url())
82 .await
83 .unwrap()
84 }
85}
86
87impl Drop for TestDb {
88 fn drop(&mut self) {
89 let server_url = self.server_url().clone();
90 let dbname = self.dbname.clone();
91 thread::spawn(move || {
93 let rt= Runtime::new().unwrap();
94 rt.block_on(async move {
95 let mut conn = PgConnection::connect(&server_url).await.unwrap();
96 query(&format!(r#"SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE pid <> pg_backend_pid() AND datname = '{}'"#,dbname))
98 .execute(&mut conn)
99 .await
100 .expect("Terminal all other connections");
101
102 conn.execute(format!(r#"DROP DATABASE "{}""#, dbname).as_str())
103 .await
104 .expect("Error while querying the reservation database");
105 });
106 })
107 .join()
108 .expect("Error thread drop database ");
109 }
110}
111
#[cfg(test)]
mod tests {
    use super::TestDb;

    /// Round-trips one row through a freshly migrated database; requires a local Postgres.
    #[tokio::test]
    async fn test_db_should_create_and_drop() {
        let tdb = TestDb::new("127.0.0.1", 5432, "zheng", "zz", "./migrations");
        let pool = tdb.get_pool().await;

        let insert = sqlx::query("INSERT INTO todos (title) VALUES ('test')");
        insert.execute(&pool).await.unwrap();

        let row: (i32, String) = sqlx::query_as("SELECT id, title FROM todos")
            .fetch_one(&pool)
            .await
            .unwrap();

        assert_eq!(row, (1, String::from("test")));
    }
}