rbdc_mssql/
lib.rs

1#![allow(mismatched_lifetime_syntaxes)]
2
3pub extern crate tiberius;
4
5pub mod decode;
6pub mod driver;
7pub mod encode;
8
9pub use crate::driver::MssqlDriver;
10pub use crate::driver::MssqlDriver as Driver;
11
12use crate::decode::Decode;
13use crate::encode::Encode;
14use futures_core::future::BoxFuture;
15use futures_core::Stream;
16use rbdc::db::{ConnectOptions, Connection, ExecResult, MetaData, Placeholder, Row};
17use rbdc::Error;
18use rbs::Value;
19use std::sync::Arc;
20use tiberius::{AuthMethod, Client, Column, ColumnData, Config, EncryptionLevel, Query};
21use tokio::net::TcpStream;
22use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt};
23use url::Url;
24use percent_encoding::percent_decode_str;
25
26pub struct MssqlConnection {
27    inner: Option<Client<Compat<TcpStream>>>,
28}
29
30impl MssqlConnection {
31    /// let cfg = Config::from_jdbc_string(url).map_err(|e| Error::from(e.to_owned()))?;
32    pub async fn establish(cfg: &Config) -> Result<Self, Error> {
33        // let cfg = Config::from_jdbc_string(url).map_err(|e| Error::from(e.to_owned()))?;
34        let tcp = TcpStream::connect(cfg.get_addr())
35            .await
36            .map_err(|e| Error::from(e.to_string()))?;
37        tcp.set_nodelay(true)?;
38        let c = Client::connect(cfg.clone(), tcp.compat_write())
39            .await
40            .map_err(|e| Error::from(e.to_string()))?;
41        Ok(Self { inner: Some(c) })
42    }
43}
44
45#[derive(Debug)]
46pub struct MssqlConnectOptions(pub Config);
47
48impl ConnectOptions for MssqlConnectOptions {
49    fn connect(&self) -> BoxFuture<Result<Box<dyn Connection>, Error>> {
50        Box::pin(async move {
51            let v = MssqlConnection::establish(&self.0)
52                .await
53                .map_err(|e| Error::from(e.to_string()))?;
54            Ok(Box::new(v) as Box<dyn Connection>)
55        })
56    }
57
58    fn set_uri(&mut self, url: &str) -> Result<(), Error> {
59        if url.contains("jdbc") {
60            let mut config = Config::from_jdbc_string(url).map_err(|e| Error::from(e.to_string()))?;
61            config.trust_cert();
62            *self = MssqlConnectOptions(config);
63        } else if url.starts_with("mssql://") || url.starts_with("sqlserver://") {
64            let mut config = parse_url_connection_string(url)?;
65            config.trust_cert();
66            *self = MssqlConnectOptions(config);
67        } else {
68            let mut config = Config::from_ado_string(url).map_err(|e| Error::from(e.to_string()))?;
69            config.trust_cert();
70            *self = MssqlConnectOptions(config);
71        }
72        Ok(())
73    }
74}
75
76/// Parse URL format connection string (mssql:// or sqlserver://)
77/// Format: mssql://user:password@host:port/database?param1=value1&param2=value2
78/// Or: sqlserver://user:password@host:port/database?param1=value1&param2=value2
79///
80/// Supported query parameters:
81/// - instance: SQL Server instance name
82/// - application_name: Application name
83/// - encrypt: Encryption level (true/false/DANGER_PLAINTEXT)
84/// - trust_cert: Whether to trust server certificate (true/false)
85/// - readonly: Read-only mode (true/false)
86fn parse_url_connection_string(url: &str) -> Result<Config, Error> {
87    let parsed_url = Url::parse(url).map_err(|e| Error::from(e.to_string()))?;
88
89    let mut config = Config::new();
90
91    // Set host
92    if let Some(host) = parsed_url.host_str() {
93        config.host(host.to_string());
94    }
95
96    // Set port
97    if let Some(port) = parsed_url.port() {
98        config.port(port);
99    }
100
101    // Set username and password
102    let username = parsed_url.username();
103    if !username.is_empty() {
104        let decoded_username = percent_decode_str(username)
105            .decode_utf8()
106            .map_err(|e| Error::from(e.to_string()))?;
107
108        if let Some(password) = parsed_url.password() {
109            let decoded_password = percent_decode_str(password)
110                .decode_utf8()
111                .map_err(|e| Error::from(e.to_string()))?;
112            config.authentication(AuthMethod::sql_server(&decoded_username, &decoded_password));
113        } else {
114            config.authentication(AuthMethod::sql_server(&decoded_username, ""));
115        }
116    }
117
118    // Set database
119    let path = parsed_url.path().trim_start_matches('/');
120    if !path.is_empty() {
121        config.database(path);
122    }
123
124    // Parse query parameters
125    for (key, value) in parsed_url.query_pairs() {
126        match key.to_lowercase().as_str() {
127            "instance" | "instance_name" => {
128                config.instance_name(&*value);
129            }
130            "application_name" | "applicationname" => {
131                config.application_name(&*value);
132            }
133            "encrypt" | "encryption" => {
134                match value.to_lowercase().as_str() {
135                    "true" | "yes" => {
136                        #[cfg(any(feature = "tls-rustls", feature = "tls-native-tls"))]
137                        config.encryption(EncryptionLevel::Required);
138                    }
139                    "false" | "no" => {
140                        #[cfg(any(feature = "tls-rustls", feature = "tls-native-tls"))]
141                        config.encryption(EncryptionLevel::Off);
142                    }
143                    "danger_plaintext" => {
144                        config.encryption(EncryptionLevel::NotSupported);
145                    }
146                    _ => {
147                        return Err(Error::from(format!("Invalid encryption value: {}", value)));
148                    }
149                }
150            }
151            "trust_cert" | "trustservercertificate" => {
152                match value.to_lowercase().as_str() {
153                    "true" | "yes" => {
154                        config.trust_cert();
155                    }
156                    "false" | "no" => {
157                        // Default behavior, no special handling needed
158                    }
159                    _ => {
160                        return Err(Error::from(format!("Invalid trust_cert value: {}", value)));
161                    }
162                }
163            }
164            "readonly" | "applicationintent" => {
165                match value.to_lowercase().as_str() {
166                    "true" | "yes" | "readonly" => {
167                        config.readonly(true);
168                    }
169                    "false" | "no" | "readwrite" => {
170                        config.readonly(false);
171                    }
172                    _ => {
173                        return Err(Error::from(format!("Invalid readonly value: {}", value)));
174                    }
175                }
176            }
177            _ => {
178                // Ignore unknown parameters
179            }
180        }
181    }
182
183    Ok(config)
184}
185
186#[derive(Debug)]
187pub struct MssqlRow {
188    pub columns: Arc<Vec<Column>>,
189    pub datas: Vec<ColumnData<'static>>,
190}
191
192#[derive(Debug)]
193pub struct MssqlMetaData(pub Arc<Vec<Column>>);
194
195impl MetaData for MssqlMetaData {
196    fn column_len(&self) -> usize {
197        self.0.len()
198    }
199
200    fn column_name(&self, i: usize) -> String {
201        self.0[i].name().to_string()
202    }
203
204    fn column_type(&self, i: usize) -> String {
205        format!("{:?}", self.0[i].column_type())
206    }
207}
208
209impl Row for MssqlRow {
210    fn meta_data(&self) -> Box<dyn MetaData> {
211        Box::new(MssqlMetaData(self.columns.clone()))
212    }
213
214    fn get(&mut self, i: usize) -> Result<Value, Error> {
215        Value::decode(&self.datas[i])
216    }
217}
218
219impl Connection for MssqlConnection {
220    fn get_rows(
221        &mut self,
222        sql: &str,
223        params: Vec<Value>,
224    ) -> BoxFuture<Result<Vec<Box<dyn Row>>, Error>> {
225        let sql = MssqlDriver {}.exchange(sql);
226        Box::pin(async move {
227            let mut q = Query::new(sql);
228            for x in params {
229                x.encode(&mut q)?;
230            }
231            let v = q
232                .query(
233                    self.inner
234                        .as_mut()
235                        .ok_or_else(|| Error::from("MssqlConnection is close"))?,
236                )
237                .await
238                .map_err(|e| Error::from(e.to_string()))?;
239            let mut results = Vec::with_capacity(v.size_hint().0);
240            let s = v
241                .into_results()
242                .await
243                .map_err(|e| Error::from(e.to_string()))?;
244            for item in s {
245                for r in item {
246                    let mut columns = Vec::with_capacity(r.columns().len());
247                    let mut row = MssqlRow {
248                        columns: Arc::new(vec![]),
249                        datas: Vec::with_capacity(r.columns().len()),
250                    };
251                    for x in r.columns() {
252                        columns.push(x.clone());
253                    }
254                    row.columns = Arc::new(columns);
255                    for x in r {
256                        row.datas.push(x);
257                    }
258                    results.push(Box::new(row) as Box<dyn Row>);
259                }
260            }
261            Ok(results)
262        })
263    }
264
265    fn exec(&mut self, sql: &str, params: Vec<Value>) -> BoxFuture<Result<ExecResult, Error>> {
266        let sql = MssqlDriver {}.exchange(sql);
267        Box::pin(async move {
268            let mut q = Query::new(sql);
269            for x in params {
270                x.encode(&mut q)?;
271            }
272            let v = q
273                .execute(
274                    self.inner
275                        .as_mut()
276                        .ok_or_else(|| Error::from("MssqlConnection is close"))?,
277                )
278                .await
279                .map_err(|e| Error::from(e.to_string()))?;
280            Ok(ExecResult {
281                rows_affected: {
282                    let mut rows_affected = 0;
283                    for x in v.rows_affected() {
284                        rows_affected += x.clone();
285                    }
286                    rows_affected
287                },
288                last_insert_id: Value::Null,
289            })
290        })
291    }
292
293    fn close(&mut self) -> BoxFuture<Result<(), Error>> {
294        Box::pin(async move {
295            //inner must be Option,so we can take owner and call close(self) method.
296            if let Some(v) = self.inner.take() {
297                v.close().await.map_err(|e| Error::from(e.to_string()))?;
298            }
299            Ok(())
300        })
301    }
302
303    fn ping(&mut self) -> BoxFuture<Result<(), rbdc::Error>> {
304        //TODO While 'select 1' can temporarily solve the problem of checking that the connection is valid, it looks ugly.Better replace it with something better way
305        Box::pin(async move {
306            self.inner
307                .as_mut()
308                .ok_or_else(|| Error::from("MssqlConnection is close"))?
309                .query("select 1", &[])
310                .await
311                .map_err(|e| Error::from(e.to_string()))?;
312            Ok(())
313        })
314    }
315
316    fn begin(&mut self) -> BoxFuture<Result<(), Error>> {
317        Box::pin(async move {
318            self.inner
319                .as_mut()
320                .ok_or_else(|| Error::from("MssqlConnection is close"))?
321                .simple_query("begin tran")
322                .await
323                .map_err(|e| Error::from(e.to_string()))?;
324            Ok(())
325        })
326    }
327
328    fn commit(&mut self) -> BoxFuture<Result<(), Error>> {
329        Box::pin(async move {
330            self.inner
331                .as_mut()
332                .ok_or_else(|| Error::from("MssqlConnection is close"))?
333                .simple_query("commit")
334                .await
335                .map_err(|e| Error::from(e.to_string()))?;
336            Ok(())
337        })
338    }
339
340    fn rollback(&mut self) -> BoxFuture<Result<(), Error>> {
341        Box::pin(async move {
342            self.inner
343                .as_mut()
344                .ok_or_else(|| Error::from("MssqlConnection is close"))?
345                .simple_query("rollback")
346                .await
347                .map_err(|e| Error::from(e.to_string()))?;
348            Ok(())
349        })
350    }
351}
352
#[cfg(test)]
mod test {
    use crate::{parse_url_connection_string, MssqlConnectOptions};
    use rbdc::db::ConnectOptions;
    use tiberius::Config;

    /// `set_uri` accepts JDBC, `mssql://`, `sqlserver://` and ADO connection strings.
    #[test]
    fn test_connection_string_parsing() {
        let cases = [
            (
                "JDBC",
                "jdbc:sqlserver://localhost:1433;User=SA;Password={TestPass!123456};Database=master;",
            ),
            ("mssql://", "mssql://SA:TestPass!123456@localhost:1433/master"),
            ("sqlserver://", "sqlserver://SA:TestPass!123456@localhost:1433/master"),
            (
                "ADO",
                "Server=localhost,1433;User Id=SA;Password=TestPass!123456;Database=master;",
            ),
        ];
        for &(format, uri) in cases.iter() {
            let mut options = MssqlConnectOptions(Config::new());
            let result = options.set_uri(uri);
            assert!(result.is_ok(), "{} format should be supported: {:?}", format, result);
        }
    }

    /// Host/port extraction for URL-style strings.
    #[test]
    fn test_url_parsing_details() {
        // Full URL with user, password, port and database.
        let config =
            parse_url_connection_string("mssql://testuser:testpass@example.com:1433/testdb").unwrap();
        assert_eq!(config.get_addr(), "example.com:1433");

        // No password.
        let config = parse_url_connection_string("mssql://testuser@localhost:1433/testdb").unwrap();
        assert_eq!(config.get_addr(), "localhost:1433");

        // No database.
        let config = parse_url_connection_string("mssql://testuser:testpass@localhost:1433").unwrap();
        assert_eq!(config.get_addr(), "localhost:1433");

        // No port: falls back to the default port.
        let config = parse_url_connection_string("mssql://testuser:testpass@localhost/testdb").unwrap();
        assert_eq!(config.get_addr(), "localhost:1433");
    }

    /// Query-parameter handling, including rejection of invalid values.
    #[test]
    fn test_url_query_parameters() {
        // URL carrying every supported parameter.
        let config = parse_url_connection_string(
            "mssql://testuser:testpass@localhost:1433/testdb?instance=SQLEXPRESS&application_name=MyApp&encrypt=true&trust_cert=true&readonly=true"
        ).unwrap();
        assert_eq!(config.get_addr(), "localhost:1433");

        // Only some parameters.
        let config = parse_url_connection_string(
            "sqlserver://user:pass@server:1433/db?application_name=TestApp&encrypt=false"
        ).unwrap();
        assert_eq!(config.get_addr(), "server:1433");

        // `danger_plaintext` is accepted regardless of enabled TLS features,
        // and values are matched case-insensitively.
        assert!(parse_url_connection_string(
            "mssql://user:pass@localhost/db?encrypt=DANGER_PLAINTEXT"
        )
        .is_ok());

        // ADO-style application intent key/value are accepted.
        assert!(parse_url_connection_string(
            "mssql://user:pass@localhost/db?ApplicationIntent=ReadOnly"
        )
        .is_ok());

        // An invalid encryption value is rejected.
        let result = parse_url_connection_string("mssql://user:pass@localhost/db?encrypt=invalid");
        assert!(result.is_err());

        // An invalid trust_cert value is rejected.
        let result = parse_url_connection_string("mssql://user:pass@localhost/db?trust_cert=invalid");
        assert!(result.is_err());

        // An invalid readonly value is rejected.
        let result = parse_url_connection_string("mssql://user:pass@localhost/db?readonly=maybe");
        assert!(result.is_err());
    }
}