rbdc_mssql/
lib.rs

1#![allow(mismatched_lifetime_syntaxes)]
2
3pub extern crate tiberius;
4
5pub mod decode;
6pub mod driver;
7pub mod encode;
8
9pub use crate::driver::MssqlDriver;
10pub use crate::driver::MssqlDriver as Driver;
11
12use crate::decode::Decode;
13use crate::encode::Encode;
14use futures_core::future::BoxFuture;
15use futures_core::Stream;
16use percent_encoding::percent_decode_str;
17use rbdc::db::{ConnectOptions, Connection, ExecResult, MetaData, Placeholder, Row};
18use rbdc::Error;
19use rbs::Value;
20use std::sync::Arc;
21use tiberius::{AuthMethod, Client, Column, ColumnData, Config, EncryptionLevel, Query};
22use tokio::net::TcpStream;
23use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt};
24use url::Url;
25
26pub struct MssqlConnection {
27    inner: Option<Client<Compat<TcpStream>>>,
28}
29
30impl MssqlConnection {
31    /// let cfg = Config::from_jdbc_string(url).map_err(|e| Error::from(e.to_owned()))?;
32    pub async fn establish(cfg: &Config) -> Result<Self, Error> {
33        // let cfg = Config::from_jdbc_string(url).map_err(|e| Error::from(e.to_owned()))?;
34        let tcp = TcpStream::connect(cfg.get_addr())
35            .await
36            .map_err(|e| Error::from(e.to_string()))?;
37        tcp.set_nodelay(true)?;
38        let c = Client::connect(cfg.clone(), tcp.compat_write())
39            .await
40            .map_err(|e| Error::from(e.to_string()))?;
41        Ok(Self { inner: Some(c) })
42    }
43}
44
45#[derive(Debug)]
46pub struct MssqlConnectOptions(pub Config);
47
48impl ConnectOptions for MssqlConnectOptions {
49    fn connect(&self) -> BoxFuture<'_, Result<Box<dyn Connection>, Error>> {
50        Box::pin(async move {
51            let v = MssqlConnection::establish(&self.0)
52                .await
53                .map_err(|e| Error::from(e.to_string()))?;
54            Ok(Box::new(v) as Box<dyn Connection>)
55        })
56    }
57
58    fn set_uri(&mut self, url: &str) -> Result<(), Error> {
59        if url.contains("jdbc") {
60            let mut config =
61                Config::from_jdbc_string(url).map_err(|e| Error::from(e.to_string()))?;
62            config.trust_cert();
63            *self = MssqlConnectOptions(config);
64        } else if url.starts_with("mssql://") || url.starts_with("sqlserver://") {
65            let mut config = parse_url_connection_string(url)?;
66            config.trust_cert();
67            *self = MssqlConnectOptions(config);
68        } else {
69            let mut config =
70                Config::from_ado_string(url).map_err(|e| Error::from(e.to_string()))?;
71            config.trust_cert();
72            *self = MssqlConnectOptions(config);
73        }
74        Ok(())
75    }
76}
77
78/// Parse URL format connection string (mssql:// or sqlserver://)
79/// Format: mssql://user:password@host:port/database?param1=value1&param2=value2
80/// Or: sqlserver://user:password@host:port/database?param1=value1&param2=value2
81///
82/// Supported query parameters:
83/// - instance: SQL Server instance name
84/// - application_name: Application name
85/// - encrypt: Encryption level (true/false/DANGER_PLAINTEXT)
86/// - trust_cert: Whether to trust server certificate (true/false)
87/// - readonly: Read-only mode (true/false)
88fn parse_url_connection_string(url: &str) -> Result<Config, Error> {
89    let parsed_url = Url::parse(url).map_err(|e| Error::from(e.to_string()))?;
90
91    let mut config = Config::new();
92
93    // Set host
94    if let Some(host) = parsed_url.host_str() {
95        config.host(host.to_string());
96    }
97
98    // Set port
99    if let Some(port) = parsed_url.port() {
100        config.port(port);
101    }
102
103    // Set username and password
104    let username = parsed_url.username();
105    if !username.is_empty() {
106        let decoded_username = percent_decode_str(username)
107            .decode_utf8()
108            .map_err(|e| Error::from(e.to_string()))?;
109
110        if let Some(password) = parsed_url.password() {
111            let decoded_password = percent_decode_str(password)
112                .decode_utf8()
113                .map_err(|e| Error::from(e.to_string()))?;
114            config.authentication(AuthMethod::sql_server(&decoded_username, &decoded_password));
115        } else {
116            config.authentication(AuthMethod::sql_server(&decoded_username, ""));
117        }
118    }
119
120    // Set database
121    let path = parsed_url.path().trim_start_matches('/');
122    if !path.is_empty() {
123        config.database(path);
124    }
125
126    // Parse query parameters
127    for (key, value) in parsed_url.query_pairs() {
128        match key.to_lowercase().as_str() {
129            "instance" | "instance_name" => {
130                config.instance_name(&*value);
131            }
132            "application_name" | "applicationname" => {
133                config.application_name(&*value);
134            }
135            "encrypt" | "encryption" => match value.to_lowercase().as_str() {
136                "true" | "yes" => {
137                    #[cfg(any(feature = "tls-rustls", feature = "tls-native-tls"))]
138                    config.encryption(EncryptionLevel::Required);
139                }
140                "false" | "no" => {
141                    #[cfg(any(feature = "tls-rustls", feature = "tls-native-tls"))]
142                    config.encryption(EncryptionLevel::Off);
143                }
144                "danger_plaintext" => {
145                    config.encryption(EncryptionLevel::NotSupported);
146                }
147                _ => {
148                    return Err(Error::from(format!("Invalid encryption value: {}", value)));
149                }
150            },
151            "trust_cert" | "trustservercertificate" => {
152                match value.to_lowercase().as_str() {
153                    "true" | "yes" => {
154                        config.trust_cert();
155                    }
156                    "false" | "no" => {
157                        // Default behavior, no special handling needed
158                    }
159                    _ => {
160                        return Err(Error::from(format!("Invalid trust_cert value: {}", value)));
161                    }
162                }
163            }
164            "readonly" | "applicationintent" => match value.to_lowercase().as_str() {
165                "true" | "yes" | "readonly" => {
166                    config.readonly(true);
167                }
168                "false" | "no" | "readwrite" => {
169                    config.readonly(false);
170                }
171                _ => {
172                    return Err(Error::from(format!("Invalid readonly value: {}", value)));
173                }
174            },
175            _ => {
176                // Ignore unknown parameters
177            }
178        }
179    }
180
181    Ok(config)
182}
183
184#[derive(Debug)]
185pub struct MssqlRow {
186    pub columns: Arc<Vec<Column>>,
187    pub datas: Vec<ColumnData<'static>>,
188}
189
190#[derive(Debug)]
191pub struct MssqlMetaData(pub Arc<Vec<Column>>);
192
193impl MetaData for MssqlMetaData {
194    fn column_len(&self) -> usize {
195        self.0.len()
196    }
197
198    fn column_name(&self, i: usize) -> String {
199        self.0[i].name().to_string()
200    }
201
202    fn column_type(&self, i: usize) -> String {
203        format!("{:?}", self.0[i].column_type())
204    }
205}
206
207impl Row for MssqlRow {
208    fn meta_data(&self) -> Box<dyn MetaData> {
209        Box::new(MssqlMetaData(self.columns.clone()))
210    }
211
212    fn get(&mut self, i: usize) -> Result<Value, Error> {
213        Value::decode(&self.datas[i])
214    }
215}
216
217impl Connection for MssqlConnection {
218    fn get_rows(
219        &mut self,
220        sql: &str,
221        params: Vec<Value>,
222    ) -> BoxFuture<'_, Result<Vec<Box<dyn Row>>, Error>> {
223        let sql = MssqlDriver {}.exchange(sql);
224        Box::pin(async move {
225            let mut q = Query::new(sql);
226            for x in params {
227                x.encode(&mut q)?;
228            }
229            let v = q
230                .query(
231                    self.inner
232                        .as_mut()
233                        .ok_or_else(|| Error::from("MssqlConnection is close"))?,
234                )
235                .await
236                .map_err(|e| Error::from(e.to_string()))?;
237            let mut results = Vec::with_capacity(v.size_hint().0);
238            let s = v
239                .into_results()
240                .await
241                .map_err(|e| Error::from(e.to_string()))?;
242            for item in s {
243                for r in item {
244                    let mut columns = Vec::with_capacity(r.columns().len());
245                    let mut row = MssqlRow {
246                        columns: Arc::new(vec![]),
247                        datas: Vec::with_capacity(r.columns().len()),
248                    };
249                    for x in r.columns() {
250                        columns.push(x.clone());
251                    }
252                    row.columns = Arc::new(columns);
253                    for x in r {
254                        row.datas.push(x);
255                    }
256                    results.push(Box::new(row) as Box<dyn Row>);
257                }
258            }
259            Ok(results)
260        })
261    }
262
263    fn exec(&mut self, sql: &str, params: Vec<Value>) -> BoxFuture<'_, Result<ExecResult, Error>> {
264        let sql = MssqlDriver {}.exchange(sql);
265        Box::pin(async move {
266            let mut q = Query::new(sql);
267            for x in params {
268                x.encode(&mut q)?;
269            }
270            let v = q
271                .execute(
272                    self.inner
273                        .as_mut()
274                        .ok_or_else(|| Error::from("MssqlConnection is close"))?,
275                )
276                .await
277                .map_err(|e| Error::from(e.to_string()))?;
278            Ok(ExecResult {
279                rows_affected: {
280                    let mut rows_affected = 0;
281                    for x in v.rows_affected() {
282                        rows_affected += x.clone();
283                    }
284                    rows_affected
285                },
286                last_insert_id: Value::Null,
287            })
288        })
289    }
290
291    fn close(&mut self) -> BoxFuture<'_, Result<(), Error>> {
292        Box::pin(async move {
293            //inner must be Option,so we can take owner and call close(self) method.
294            if let Some(v) = self.inner.take() {
295                v.close().await.map_err(|e| Error::from(e.to_string()))?;
296            }
297            Ok(())
298        })
299    }
300
301    fn ping(&mut self) -> BoxFuture<'_, Result<(), rbdc::Error>> {
302        //TODO While 'select 1' can temporarily solve the problem of checking that the connection is valid, it looks ugly.Better replace it with something better way
303        Box::pin(async move {
304            self.inner
305                .as_mut()
306                .ok_or_else(|| Error::from("MssqlConnection is close"))?
307                .query("select 1", &[])
308                .await
309                .map_err(|e| Error::from(e.to_string()))?;
310            Ok(())
311        })
312    }
313
314    fn begin(&mut self) -> BoxFuture<'_, Result<(), Error>> {
315        Box::pin(async move {
316            self.inner
317                .as_mut()
318                .ok_or_else(|| Error::from("MssqlConnection is close"))?
319                .simple_query("begin tran")
320                .await
321                .map_err(|e| Error::from(e.to_string()))?;
322            Ok(())
323        })
324    }
325
326    fn commit(&mut self) -> BoxFuture<'_, Result<(), Error>> {
327        Box::pin(async move {
328            self.inner
329                .as_mut()
330                .ok_or_else(|| Error::from("MssqlConnection is close"))?
331                .simple_query("commit")
332                .await
333                .map_err(|e| Error::from(e.to_string()))?;
334            Ok(())
335        })
336    }
337
338    fn rollback(&mut self) -> BoxFuture<'_, Result<(), Error>> {
339        Box::pin(async move {
340            self.inner
341                .as_mut()
342                .ok_or_else(|| Error::from("MssqlConnection is close"))?
343                .simple_query("rollback")
344                .await
345                .map_err(|e| Error::from(e.to_string()))?;
346            Ok(())
347        })
348    }
349}
350
#[cfg(test)]
mod test {
    use crate::{parse_url_connection_string, MssqlConnectOptions};
    use rbdc::db::ConnectOptions;
    use tiberius::Config;

    /// Parses `uri` as a URL connection string and returns the resulting address.
    fn addr_of(uri: &str) -> String {
        parse_url_connection_string(uri).unwrap().get_addr()
    }

    #[test]
    fn test_datetime() {}

    #[test]
    fn test_connection_string_parsing() {
        // JDBC format
        let mut jdbc_options = MssqlConnectOptions(Config::new());
        let jdbc_result = jdbc_options.set_uri(
            "jdbc:sqlserver://localhost:1433;User=SA;Password={TestPass!123456};Database=master;",
        );
        assert!(jdbc_result.is_ok(), "JDBC format should be supported");

        // mssql:// format
        let mut mssql_options = MssqlConnectOptions(Config::new());
        let mssql_result =
            mssql_options.set_uri("mssql://SA:TestPass!123456@localhost:1433/master");
        assert!(
            mssql_result.is_ok(),
            "mssql:// format should be supported: {:?}",
            mssql_result
        );

        // sqlserver:// format
        let mut sqlserver_options = MssqlConnectOptions(Config::new());
        let sqlserver_result =
            sqlserver_options.set_uri("sqlserver://SA:TestPass!123456@localhost:1433/master");
        assert!(
            sqlserver_result.is_ok(),
            "sqlserver:// format should be supported: {:?}",
            sqlserver_result
        );

        // ADO format
        let mut ado_options = MssqlConnectOptions(Config::new());
        let ado_result = ado_options
            .set_uri("Server=localhost,1433;User Id=SA;Password=TestPass!123456;Database=master;");
        assert!(ado_result.is_ok(), "ADO format should be supported");
    }

    #[test]
    fn test_url_parsing_details() {
        let cases = [
            // Full URL
            (
                "mssql://testuser:testpass@example.com:1433/testdb",
                "example.com:1433",
            ),
            // No password
            ("mssql://testuser@localhost:1433/testdb", "localhost:1433"),
            // No database
            ("mssql://testuser:testpass@localhost:1433", "localhost:1433"),
            // Default port
            ("mssql://testuser:testpass@localhost/testdb", "localhost:1433"),
        ];
        for (uri, expected_addr) in cases {
            assert_eq!(addr_of(uri), expected_addr);
        }
    }

    #[test]
    fn test_url_query_parameters() {
        // URL with the full set of query parameters
        assert_eq!(
            addr_of(
                "mssql://testuser:testpass@localhost:1433/testdb?instance=SQLEXPRESS&application_name=MyApp&encrypt=true&trust_cert=true&readonly=true"
            ),
            "localhost:1433"
        );

        // URL with only some query parameters
        assert_eq!(
            addr_of("sqlserver://user:pass@server:1433/db?application_name=TestApp&encrypt=false"),
            "server:1433"
        );

        // An invalid encrypt value or an invalid trust_cert value must be rejected
        let rejected = [
            "mssql://user:pass@localhost/db?encrypt=invalid",
            "mssql://user:pass@localhost/db?trust_cert=invalid",
        ];
        for uri in rejected {
            assert!(parse_url_connection_string(uri).is_err());
        }
    }
}