tiberius_db_tester/
lib.rs

1use std::{fs::File, io::Read as _, thread};
2
3use tiberius::{Client, Query};
4use tokio::{net::TcpStream, runtime::Runtime};
5use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt as _};
6use uuid::Uuid;
7
8#[derive(Clone, Debug)]
9pub struct DBTester {
10    pub config: tiberius::Config,
11    pub db_name: String,
12}
13
14impl DBTester {
15    pub fn new(
16        host: impl ToString,
17        port: u16,
18        username: impl ToString,
19        password: impl ToString,
20        miration_file_path: impl ToString,
21    ) -> DBTester {
22        let mut config = tiberius::Config::new();
23        let db_name = format!("testdb_{}", Uuid::new_v4()).replace("-", "_");
24        config.host(host);
25        config.port(port);
26        config.authentication(tiberius::AuthMethod::sql_server(username, password));
27        config.trust_cert();
28        let config_clone = config.clone();
29        let db_name_clone = db_name.clone();
30        let dbt = DBTester { config, db_name };
31        let sql = read_sql_file(miration_file_path);
32
33        thread::spawn(move || {
34            let rt = Runtime::new().expect("Failed to create runtime");
35            rt.block_on(async move {
36                let tcp = TcpStream::connect(config_clone.get_addr())
37                    .await
38                    .expect("Failed to connect to SQL Server");
39                tcp.set_nodelay(true).expect("Failed to set TCP_NODELAY");
40
41                let mut client = Client::connect(config_clone.clone(), tcp.compat_write())
42                    .await
43                    .expect("Failed to create client");
44                let query = Query::new(format!("CREATE DATABASE {}", db_name_clone));
45                query
46                    .execute(&mut client)
47                    .await
48                    .expect("Failed to create database");
49
50                // 切换到新创建的数据库
51                let query = Query::new(format!("USE {}", db_name_clone));
52                query
53                    .execute(&mut client)
54                    .await
55                    .expect("Failed to switch to database");
56                // 执行迁移
57                let query = Query::new(sql);
58                query
59                    .execute(&mut client)
60                    .await
61                    .expect("Failed to execute query");
62            });
63        })
64        .join()
65        .expect("");
66        dbt
67    }
68    pub fn get_config(&self) -> tiberius::Config {
69        self.config.clone()
70    }
71    pub async fn get_client(&self) -> Client<Compat<TcpStream>> {
72        let mut config = self.config.clone();
73        config.database(self.db_name.clone());
74        let tcp = TcpStream::connect(config.get_addr())
75            .await
76            .expect("Failed to connect to SQL Server");
77        tcp.set_nodelay(true).expect("Failed to set TCP_NODELAY");
78
79        let client = Client::connect(config, tcp.compat_write())
80            .await
81            .expect("Failed to create client");
82        client
83    }
84}
85
86impl Drop for DBTester {
87    fn drop(&mut self) {
88        let config = self.config.clone();
89        let db_name = self.db_name.clone();
90        thread::spawn(move || {
91            let rt = Runtime::new().expect("Failed to create runtime");
92            rt.block_on(async move {
93                let tcp = TcpStream::connect(config.get_addr())
94                    .await
95                    .expect("Failed to connect to SQL Server");
96                tcp.set_nodelay(true).expect("Failed to set TCP_NODELAY");
97
98                let mut client = Client::connect(config.clone(), tcp.compat_write())
99                    .await
100                    .expect("Failed to create client");
101
102                // 首先切换到master数据库
103                let switch_query = Query::new("USE master");
104                switch_query
105                    .execute(&mut client)
106                    .await
107                    .expect("Failed to switch to master database");
108
109                // 关闭所有到目标数据库的连接
110                let kill_connections_query = Query::new(format!(
111                    "ALTER DATABASE {} SET SINGLE_USER WITH ROLLBACK IMMEDIATE",
112                    db_name
113                ));
114                kill_connections_query
115                    .execute(&mut client)
116                    .await
117                    .expect("Failed to kill database connections");
118
119                // 删除数据库
120                let query = Query::new(format!("DROP DATABASE {}", db_name));
121                query
122                    .execute(&mut client)
123                    .await
124                    .expect("Failed to drop database");
125            });
126        })
127        .join()
128        .expect("`Drop` should never panic");
129    }
130}
131
/// Reads the whole file at `file_path` into a `String`.
///
/// # Panics
///
/// Panics if the file cannot be opened or read (including invalid UTF-8).
/// The message names the offending path and the underlying I/O error.
fn read_sql_file(file_path: impl ToString) -> String {
    let path = file_path.to_string();
    let mut content = String::new();
    File::open(&path)
        .and_then(|mut file| file.read_to_string(&mut content))
        .unwrap_or_else(|err| panic!("Failed to read file `{}`: {}", path, err));
    content
}
139
#[cfg(test)]
mod tests {
    use super::*;

    /// Reads live-server settings from the environment so credentials never
    /// live in source control.
    ///
    /// Uses `MSSQL_HOST` (required), `MSSQL_PORT` (default 1433), `MSSQL_USER`
    /// (default `sa`) and `MSSQL_PASSWORD` (default empty). Returns `None` when
    /// `MSSQL_HOST` is unset.
    fn server_from_env() -> Option<(String, u16, String, String)> {
        let host = std::env::var("MSSQL_HOST").ok()?;
        let port = std::env::var("MSSQL_PORT")
            .ok()
            .and_then(|p| p.parse().ok())
            .unwrap_or(1433);
        let user = std::env::var("MSSQL_USER").unwrap_or_else(|_| String::from("sa"));
        let password = std::env::var("MSSQL_PASSWORD").unwrap_or_default();
        Some((host, port, user, password))
    }

    #[tokio::test]
    async fn new_should_work() {
        // Integration test against a live SQL Server; skip when none is configured.
        let (host, port, user, password) = match server_from_env() {
            Some(settings) => settings,
            None => {
                eprintln!("skipping new_should_work: MSSQL_HOST is not set");
                return;
            }
        };
        let dbt = DBTester::new(host, port, user, password, "migrates/test.sql");
        let mut client = dbt.get_client().await;
        let result = Query::new("SELECT * FROM test")
            .query(&mut client)
            .await
            .unwrap();
        let rows = result.into_first_result().await.unwrap();
        // NOTE(review): assumes migrates/test.sql inserts exactly two rows into `test`.
        assert_eq!(rows.len(), 2);
    }

    #[test]
    fn read_sql_file_should_work() {
        let sql = read_sql_file("migrates/test.sql");
        assert!(!sql.is_empty());
    }
}