db_sqlx_tester/
lib.rs

1use std::{path::Path, thread};
2
3use sqlx::{migrate::Migrator, query, Connection, Executor, PgConnection, PgPool};
4use tokio::runtime::Runtime;
5use uuid::Uuid;
6
/// Handle to a throwaway PostgreSQL database.
///
/// `TestDb::new` creates the database on the server described by these
/// fields, and the `Drop` implementation removes it again, so the database
/// lives exactly as long as this value.
pub struct TestDb {
    /// Host name or address of the PostgreSQL server.
    pub host: String,
    /// TCP port of the PostgreSQL server.
    pub port: u16,
    /// User name placed in the connection URL.
    pub user: String,
    /// Password placed in the connection URL; when empty, the URL is built
    /// without a password component.
    pub password: String,
    /// Name of the database this handle refers to (`test_<uuid>` when built
    /// through `TestDb::new`).
    pub dbname: String,
}
14
15impl TestDb {
16    pub fn new(
17        host: impl Into<String>,
18        port: u16,
19        user: impl Into<String>,
20        password: impl Into<String>,
21        migration_path: impl Into<String>,
22    ) -> Self {
23        let host = host.into();
24        let user = user.into();
25        let password = password.into();
26
27        let uuid = Uuid::new_v4();
28        let dbname = format!("test_{}", uuid);
29        let dbname_cloned = dbname.clone();
30        let tdb = Self {
31            host,
32            port,
33            user,
34            password,
35            dbname,
36        };
37
38        let server_url = tdb.server_url();
39        let url = tdb.url();
40        let migration_path = migration_path.into();
41
42        thread::spawn(move || {
43            let rt = Runtime::new().unwrap();
44            rt.block_on(async move {
45                // use server url to create database
46                let mut conn = PgConnection::connect(&server_url).await.unwrap();
47                conn.execute(format!(r#"CREATE DATABASE "{}""#, dbname_cloned).as_str())
48                    .await
49                    .expect("Error while querying the reservation database");
50
51                // now connect to test database for migration
52                let mut conn = PgConnection::connect(&url).await.unwrap();
53                let m = Migrator::new(Path::new(&migration_path)).await.unwrap();
54                m.run(&mut conn).await.unwrap();
55            });
56        })
57        .join()
58        .expect("Error thread create database ");
59
60        tdb
61    }
62
63    pub fn server_url(&self) -> String {
64        if !self.password.is_empty() {
65            format!(
66                "postgres://{}:{}@{}:{}",
67                self.user, self.password, self.host, self.port
68            )
69        } else {
70            format!("postgres://{}@{}:{}", self.user, self.host, self.port)
71        }
72    }
73
74    pub fn url(&self) -> String {
75        format!("{}/{}", self.server_url(), self.dbname)
76    }
77
78    pub async fn get_pool(&self) -> PgPool {
79        sqlx::postgres::PgPoolOptions::new()
80            .max_connections(5)
81            .connect(&self.url())
82            .await
83            .unwrap()
84    }
85}
86
87impl Drop for TestDb {
88    fn drop(&mut self) {
89        let server_url = self.server_url().clone();
90        let dbname = self.dbname.clone();
91        // drop 时删除数据库
92        thread::spawn(move || {
93            let rt= Runtime::new().unwrap();
94            rt.block_on(async move {
95                let mut conn = PgConnection::connect(&server_url).await.unwrap();
96                // terminate existing connection`中断现有连接
97                query(&format!(r#"SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE pid <> pg_backend_pid() AND datname = '{}'"#,dbname))
98                .execute(&mut conn)
99                .await
100                .expect("Terminal all other connections");
101
102                conn.execute(format!(r#"DROP DATABASE "{}""#, dbname).as_str())
103                    .await
104                    .expect("Error while querying the reservation database");
105            });
106        })
107        .join()
108        .expect("Error thread drop database ");
109    }
110}
111
#[cfg(test)]
mod tests {
    use super::TestDb;

    /// Round-trips one row through a freshly created test database.
    ///
    /// Requires a PostgreSQL server on 127.0.0.1:5432 that accepts the
    /// credentials below, plus a `./migrations` directory whose migrations
    /// create the `todos` table used here.
    #[tokio::test]
    async fn test_db_should_create_and_drop() {
        // `tdb` is declared before `pool` so the pool is dropped first and
        // the database is dropped afterwards.
        let tdb = TestDb::new("127.0.0.1", 5432, "zheng", "zz", "./migrations");
        let pool = tdb.get_pool().await;

        // Write a single todo.
        let insert = sqlx::query("INSERT INTO todos (title) VALUES ('test')");
        insert.execute(&pool).await.unwrap();

        // Read it back.
        let todo = sqlx::query_as::<_, (i32, String)>("SELECT id, title FROM todos")
            .fetch_one(&pool)
            .await
            .unwrap();

        assert_eq!(todo.0, 1);
        assert_eq!(todo.1, "test");
    }
}