sqlx_tester/
postgres.rs

1use std::{path::Path, thread};
2
3use sqlx::{
4    migrate::{MigrationSource, Migrator},
5    Connection, Executor, PgConnection, PgPool,
6};
7use tokio::runtime::Runtime;
8use uuid::Uuid;
9
/// A throwaway PostgreSQL database for tests.
///
/// `TestPg::new` creates a database named `test_<uuid>` on `server_url`, and dropping the value
/// drops that database again.
#[derive(Debug)]
pub struct TestPg {
    /// URL of the PostgreSQL server itself, without a database path
    /// (e.g. `postgres://user:pass@host:5432`); the database name is appended by `url()`.
    pub server_url: String,
    /// Name of the test database created on that server.
    pub dbname: String,
}
15
16impl TestPg {
17    pub fn new<S>(server_url: String, migrations: S) -> Self
18    where
19        S: MigrationSource<'static> + Send + Sync + 'static,
20    {
21        let uuid = Uuid::new_v4();
22        let dbname = format!("test_{}", uuid);
23        let dbname_cloned = dbname.clone();
24
25        let tdb = Self { server_url, dbname };
26
27        let server_url = tdb.server_url();
28        let url = tdb.url();
29
30        // create database dbname
31        thread::spawn(move || {
32            let rt = Runtime::new().unwrap();
33            rt.block_on(async move {
34                // use server url to crate database
35                let mut conn = PgConnection::connect(&server_url).await.unwrap();
36                conn.execute(format!(r#"CREATE DATABASE "{}""#, dbname_cloned).as_str())
37                    .await
38                    .unwrap();
39
40                // now connect to test database for migration
41                let mut conn = PgConnection::connect(&url).await.unwrap();
42                let m = Migrator::new(migrations).await.unwrap();
43                m.run(&mut conn).await.unwrap();
44            });
45        })
46        .join()
47        .expect("failed to create database");
48
49        tdb
50    }
51
52    pub fn server_url(&self) -> String {
53        self.server_url.clone()
54    }
55
56    pub fn url(&self) -> String {
57        format!("{}/{}", self.server_url, self.dbname)
58    }
59
60    pub async fn get_pool(&self) -> PgPool {
61        PgPool::connect(&self.url()).await.unwrap()
62    }
63}
64
65impl Drop for TestPg {
66    fn drop(&mut self) {
67        let server_url = self.server_url();
68        let dbname = self.dbname.clone();
69        thread::spawn(move ||{
70            let rt = Runtime::new().unwrap();
71            rt.block_on(async move{
72                let mut conn = PgConnection::connect(&server_url).await.unwrap();
73                // terminate existing connections
74                sqlx::query(&format!(r#"SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE pid <> pg_backend_pid() AND datname = '{}'"#, dbname))
75                .execute(&mut conn)
76                .await
77                .expect("Terminate all other connections");
78                conn.execute(format!(r#"DROP DATABASE "{}""#, dbname).as_str())
79                .await
80                .expect("Error while querying the drop database");
81            });
82        })
83        .join()
84        .expect("failed to drop database");
85    }
86}
87
88impl Default for TestPg {
89    fn default() -> Self {
90        Self::new(
91            "postgres://postgres:admin123@172.18.3.1:5432".to_string(),
92            Path::new("./migrations"),
93        )
94    }
95}
96
#[cfg(test)]
mod tests {
    use super::TestPg;

    /// Creates a test database, writes and reads one row, then lets `TestPg` drop the database.
    // NOTE(review): assumes ./migrations creates a `todos` table with an auto-incrementing
    // integer `id` starting at 1 and a text `title` column — confirm against the migrations.
    #[tokio::test]
    async fn test_postgres_should_create_and_drop() {
        let tdb = TestPg::default();
        let pool = tdb.get_pool().await;

        // insert todo
        sqlx::query("INSERT INTO todos (title) VALUES ('test')")
            .execute(&pool)
            .await
            .unwrap();

        // get todo
        let (id, title) = sqlx::query_as::<_, (i32, String)>("SELECT id, title FROM todos")
            .fetch_one(&pool)
            .await
            .unwrap();

        assert_eq!(id, 1);
        assert_eq!(title, "test");

        // Close the pool's connections before `tdb` goes out of scope, so its `Drop` impl does
        // not have to terminate live sessions in order to drop the database.
        pool.close().await;
    }
}