tiberius_db_tester/
lib.rs1use std::{fs::File, io::Read as _, thread};
2
3use tiberius::{Client, Query};
4use tokio::{net::TcpStream, runtime::Runtime};
5use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt as _};
6use uuid::Uuid;
7
8#[derive(Clone, Debug)]
9pub struct DBTester {
10 pub config: tiberius::Config,
11 pub db_name: String,
12}
13
14impl DBTester {
15 pub fn new(
16 host: impl ToString,
17 port: u16,
18 username: impl ToString,
19 password: impl ToString,
20 miration_file_path: impl ToString,
21 ) -> DBTester {
22 let mut config = tiberius::Config::new();
23 let db_name = format!("testdb_{}", Uuid::new_v4()).replace("-", "_");
24 config.host(host);
25 config.port(port);
26 config.authentication(tiberius::AuthMethod::sql_server(username, password));
27 config.trust_cert();
28 let config_clone = config.clone();
29 let db_name_clone = db_name.clone();
30 let dbt = DBTester { config, db_name };
31 let sql = read_sql_file(miration_file_path);
32
33 thread::spawn(move || {
34 let rt = Runtime::new().expect("Failed to create runtime");
35 rt.block_on(async move {
36 let tcp = TcpStream::connect(config_clone.get_addr())
37 .await
38 .expect("Failed to connect to SQL Server");
39 tcp.set_nodelay(true).expect("Failed to set TCP_NODELAY");
40
41 let mut client = Client::connect(config_clone.clone(), tcp.compat_write())
42 .await
43 .expect("Failed to create client");
44 let query = Query::new(format!("CREATE DATABASE {}", db_name_clone));
45 query
46 .execute(&mut client)
47 .await
48 .expect("Failed to create database");
49
50 let query = Query::new(format!("USE {}", db_name_clone));
52 query
53 .execute(&mut client)
54 .await
55 .expect("Failed to switch to database");
56 let query = Query::new(sql);
58 query
59 .execute(&mut client)
60 .await
61 .expect("Failed to execute query");
62 });
63 })
64 .join()
65 .expect("");
66 dbt
67 }
68 pub fn get_config(&self) -> tiberius::Config {
69 self.config.clone()
70 }
71 pub async fn get_client(&self) -> Client<Compat<TcpStream>> {
72 let mut config = self.config.clone();
73 config.database(self.db_name.clone());
74 let tcp = TcpStream::connect(config.get_addr())
75 .await
76 .expect("Failed to connect to SQL Server");
77 tcp.set_nodelay(true).expect("Failed to set TCP_NODELAY");
78
79 let client = Client::connect(config, tcp.compat_write())
80 .await
81 .expect("Failed to create client");
82 client
83 }
84}
85
86impl Drop for DBTester {
87 fn drop(&mut self) {
88 let config = self.config.clone();
89 let db_name = self.db_name.clone();
90 thread::spawn(move || {
91 let rt = Runtime::new().expect("Failed to create runtime");
92 rt.block_on(async move {
93 let tcp = TcpStream::connect(config.get_addr())
94 .await
95 .expect("Failed to connect to SQL Server");
96 tcp.set_nodelay(true).expect("Failed to set TCP_NODELAY");
97
98 let mut client = Client::connect(config.clone(), tcp.compat_write())
99 .await
100 .expect("Failed to create client");
101
102 let switch_query = Query::new("USE master");
104 switch_query
105 .execute(&mut client)
106 .await
107 .expect("Failed to switch to master database");
108
109 let kill_connections_query = Query::new(format!(
111 "ALTER DATABASE {} SET SINGLE_USER WITH ROLLBACK IMMEDIATE",
112 db_name
113 ));
114 kill_connections_query
115 .execute(&mut client)
116 .await
117 .expect("Failed to kill database connections");
118
119 let query = Query::new(format!("DROP DATABASE {}", db_name));
121 query
122 .execute(&mut client)
123 .await
124 .expect("Failed to drop database");
125 });
126 })
127 .join()
128 .expect("`Drop` should never panic");
129 }
130}
131
/// Reads the whole SQL script at `file_path` into a string.
///
/// # Panics
///
/// Panics if the file cannot be opened or is not valid UTF-8. The message names the path
/// and the underlying I/O error so that a failing `DBTester::new` is easy to diagnose.
fn read_sql_file(file_path: impl ToString) -> String {
    let path = file_path.to_string();
    let mut file =
        File::open(&path).unwrap_or_else(|e| panic!("Failed to open file {:?}: {}", path, e));
    let mut content = String::new();
    file.read_to_string(&mut content)
        .unwrap_or_else(|e| panic!("Failed to read file {:?}: {}", path, e));
    content
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Reads the SQL Server connection settings for the integration test from the
    /// environment, so no server address or credentials are committed to the repository.
    ///
    /// Returns `None` unless `DB_TESTER_HOST`, `DB_TESTER_USER` and `DB_TESTER_PASSWORD` are
    /// all set; `DB_TESTER_PORT` is optional and defaults to 1433.
    fn server_from_env() -> Option<(String, u16, String, String)> {
        let host = std::env::var("DB_TESTER_HOST").ok()?;
        let user = std::env::var("DB_TESTER_USER").ok()?;
        let password = std::env::var("DB_TESTER_PASSWORD").ok()?;
        let port = match std::env::var("DB_TESTER_PORT") {
            Ok(port) => port
                .parse()
                .expect("DB_TESTER_PORT must be a valid TCP port number"),
            Err(_) => 1433,
        };
        Some((host, port, user, password))
    }

    #[tokio::test]
    async fn new_should_work() {
        let (host, port, user, password) = match server_from_env() {
            Some(settings) => settings,
            None => {
                eprintln!(
                    "skipping new_should_work: set DB_TESTER_HOST, DB_TESTER_USER and \
                     DB_TESTER_PASSWORD to run it against a SQL Server"
                );
                return;
            }
        };
        let dbt = DBTester::new(host, port, user, password, "migrates/test.sql");
        let mut client = dbt.get_client().await;
        let result = Query::new("SELECT * FROM test")
            .query(&mut client)
            .await
            .unwrap();
        let rows = result.into_first_result().await.unwrap();
        // NOTE(review): assumes migrates/test.sql inserts exactly two rows into `test`.
        assert_eq!(rows.len(), 2);
    }

    #[test]
    fn read_sql_file_should_work() {
        let sql = read_sql_file("migrates/test.sql");
        assert!(!sql.is_empty());
    }
}