1#![allow(mismatched_lifetime_syntaxes)]
2
3pub extern crate tiberius;
4
5pub mod decode;
6pub mod driver;
7pub mod encode;
8
9pub use crate::driver::MssqlDriver;
10pub use crate::driver::MssqlDriver as Driver;
11
12use crate::decode::Decode;
13use crate::encode::Encode;
14use futures_core::future::BoxFuture;
15use futures_core::Stream;
16use rbdc::db::{ConnectOptions, Connection, ExecResult, MetaData, Placeholder, Row};
17use rbdc::Error;
18use rbs::Value;
19use std::sync::Arc;
20use tiberius::{AuthMethod, Client, Column, ColumnData, Config, EncryptionLevel, Query};
21use tokio::net::TcpStream;
22use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt};
23use url::Url;
24use percent_encoding::percent_decode_str;
25
26pub struct MssqlConnection {
27 inner: Option<Client<Compat<TcpStream>>>,
28}
29
30impl MssqlConnection {
31 pub async fn establish(cfg: &Config) -> Result<Self, Error> {
33 let tcp = TcpStream::connect(cfg.get_addr())
35 .await
36 .map_err(|e| Error::from(e.to_string()))?;
37 tcp.set_nodelay(true)?;
38 let c = Client::connect(cfg.clone(), tcp.compat_write())
39 .await
40 .map_err(|e| Error::from(e.to_string()))?;
41 Ok(Self { inner: Some(c) })
42 }
43}
44
45#[derive(Debug)]
46pub struct MssqlConnectOptions(pub Config);
47
48impl ConnectOptions for MssqlConnectOptions {
49 fn connect(&self) -> BoxFuture<Result<Box<dyn Connection>, Error>> {
50 Box::pin(async move {
51 let v = MssqlConnection::establish(&self.0)
52 .await
53 .map_err(|e| Error::from(e.to_string()))?;
54 Ok(Box::new(v) as Box<dyn Connection>)
55 })
56 }
57
58 fn set_uri(&mut self, url: &str) -> Result<(), Error> {
59 if url.contains("jdbc") {
60 let mut config = Config::from_jdbc_string(url).map_err(|e| Error::from(e.to_string()))?;
61 config.trust_cert();
62 *self = MssqlConnectOptions(config);
63 } else if url.starts_with("mssql://") || url.starts_with("sqlserver://") {
64 let mut config = parse_url_connection_string(url)?;
65 config.trust_cert();
66 *self = MssqlConnectOptions(config);
67 } else {
68 let mut config = Config::from_ado_string(url).map_err(|e| Error::from(e.to_string()))?;
69 config.trust_cert();
70 *self = MssqlConnectOptions(config);
71 }
72 Ok(())
73 }
74}
75
76fn parse_url_connection_string(url: &str) -> Result<Config, Error> {
87 let parsed_url = Url::parse(url).map_err(|e| Error::from(e.to_string()))?;
88
89 let mut config = Config::new();
90
91 if let Some(host) = parsed_url.host_str() {
93 config.host(host.to_string());
94 }
95
96 if let Some(port) = parsed_url.port() {
98 config.port(port);
99 }
100
101 let username = parsed_url.username();
103 if !username.is_empty() {
104 let decoded_username = percent_decode_str(username)
105 .decode_utf8()
106 .map_err(|e| Error::from(e.to_string()))?;
107
108 if let Some(password) = parsed_url.password() {
109 let decoded_password = percent_decode_str(password)
110 .decode_utf8()
111 .map_err(|e| Error::from(e.to_string()))?;
112 config.authentication(AuthMethod::sql_server(&decoded_username, &decoded_password));
113 } else {
114 config.authentication(AuthMethod::sql_server(&decoded_username, ""));
115 }
116 }
117
118 let path = parsed_url.path().trim_start_matches('/');
120 if !path.is_empty() {
121 config.database(path);
122 }
123
124 for (key, value) in parsed_url.query_pairs() {
126 match key.to_lowercase().as_str() {
127 "instance" | "instance_name" => {
128 config.instance_name(&*value);
129 }
130 "application_name" | "applicationname" => {
131 config.application_name(&*value);
132 }
133 "encrypt" | "encryption" => {
134 match value.to_lowercase().as_str() {
135 "true" | "yes" => {
136 #[cfg(any(feature = "tls-rustls", feature = "tls-native-tls"))]
137 config.encryption(EncryptionLevel::Required);
138 }
139 "false" | "no" => {
140 #[cfg(any(feature = "tls-rustls", feature = "tls-native-tls"))]
141 config.encryption(EncryptionLevel::Off);
142 }
143 "danger_plaintext" => {
144 config.encryption(EncryptionLevel::NotSupported);
145 }
146 _ => {
147 return Err(Error::from(format!("Invalid encryption value: {}", value)));
148 }
149 }
150 }
151 "trust_cert" | "trustservercertificate" => {
152 match value.to_lowercase().as_str() {
153 "true" | "yes" => {
154 config.trust_cert();
155 }
156 "false" | "no" => {
157 }
159 _ => {
160 return Err(Error::from(format!("Invalid trust_cert value: {}", value)));
161 }
162 }
163 }
164 "readonly" | "applicationintent" => {
165 match value.to_lowercase().as_str() {
166 "true" | "yes" | "readonly" => {
167 config.readonly(true);
168 }
169 "false" | "no" | "readwrite" => {
170 config.readonly(false);
171 }
172 _ => {
173 return Err(Error::from(format!("Invalid readonly value: {}", value)));
174 }
175 }
176 }
177 _ => {
178 }
180 }
181 }
182
183 Ok(config)
184}
185
186#[derive(Debug)]
187pub struct MssqlRow {
188 pub columns: Arc<Vec<Column>>,
189 pub datas: Vec<ColumnData<'static>>,
190}
191
192#[derive(Debug)]
193pub struct MssqlMetaData(pub Arc<Vec<Column>>);
194
195impl MetaData for MssqlMetaData {
196 fn column_len(&self) -> usize {
197 self.0.len()
198 }
199
200 fn column_name(&self, i: usize) -> String {
201 self.0[i].name().to_string()
202 }
203
204 fn column_type(&self, i: usize) -> String {
205 format!("{:?}", self.0[i].column_type())
206 }
207}
208
209impl Row for MssqlRow {
210 fn meta_data(&self) -> Box<dyn MetaData> {
211 Box::new(MssqlMetaData(self.columns.clone()))
212 }
213
214 fn get(&mut self, i: usize) -> Result<Value, Error> {
215 Value::decode(&self.datas[i])
216 }
217}
218
219impl Connection for MssqlConnection {
220 fn get_rows(
221 &mut self,
222 sql: &str,
223 params: Vec<Value>,
224 ) -> BoxFuture<Result<Vec<Box<dyn Row>>, Error>> {
225 let sql = MssqlDriver {}.exchange(sql);
226 Box::pin(async move {
227 let mut q = Query::new(sql);
228 for x in params {
229 x.encode(&mut q)?;
230 }
231 let v = q
232 .query(
233 self.inner
234 .as_mut()
235 .ok_or_else(|| Error::from("MssqlConnection is close"))?,
236 )
237 .await
238 .map_err(|e| Error::from(e.to_string()))?;
239 let mut results = Vec::with_capacity(v.size_hint().0);
240 let s = v
241 .into_results()
242 .await
243 .map_err(|e| Error::from(e.to_string()))?;
244 for item in s {
245 for r in item {
246 let mut columns = Vec::with_capacity(r.columns().len());
247 let mut row = MssqlRow {
248 columns: Arc::new(vec![]),
249 datas: Vec::with_capacity(r.columns().len()),
250 };
251 for x in r.columns() {
252 columns.push(x.clone());
253 }
254 row.columns = Arc::new(columns);
255 for x in r {
256 row.datas.push(x);
257 }
258 results.push(Box::new(row) as Box<dyn Row>);
259 }
260 }
261 Ok(results)
262 })
263 }
264
265 fn exec(&mut self, sql: &str, params: Vec<Value>) -> BoxFuture<Result<ExecResult, Error>> {
266 let sql = MssqlDriver {}.exchange(sql);
267 Box::pin(async move {
268 let mut q = Query::new(sql);
269 for x in params {
270 x.encode(&mut q)?;
271 }
272 let v = q
273 .execute(
274 self.inner
275 .as_mut()
276 .ok_or_else(|| Error::from("MssqlConnection is close"))?,
277 )
278 .await
279 .map_err(|e| Error::from(e.to_string()))?;
280 Ok(ExecResult {
281 rows_affected: {
282 let mut rows_affected = 0;
283 for x in v.rows_affected() {
284 rows_affected += x.clone();
285 }
286 rows_affected
287 },
288 last_insert_id: Value::Null,
289 })
290 })
291 }
292
293 fn close(&mut self) -> BoxFuture<Result<(), Error>> {
294 Box::pin(async move {
295 if let Some(v) = self.inner.take() {
297 v.close().await.map_err(|e| Error::from(e.to_string()))?;
298 }
299 Ok(())
300 })
301 }
302
303 fn ping(&mut self) -> BoxFuture<Result<(), rbdc::Error>> {
304 Box::pin(async move {
306 self.inner
307 .as_mut()
308 .ok_or_else(|| Error::from("MssqlConnection is close"))?
309 .query("select 1", &[])
310 .await
311 .map_err(|e| Error::from(e.to_string()))?;
312 Ok(())
313 })
314 }
315
316 fn begin(&mut self) -> BoxFuture<Result<(), Error>> {
317 Box::pin(async move {
318 self.inner
319 .as_mut()
320 .ok_or_else(|| Error::from("MssqlConnection is close"))?
321 .simple_query("begin tran")
322 .await
323 .map_err(|e| Error::from(e.to_string()))?;
324 Ok(())
325 })
326 }
327
328 fn commit(&mut self) -> BoxFuture<Result<(), Error>> {
329 Box::pin(async move {
330 self.inner
331 .as_mut()
332 .ok_or_else(|| Error::from("MssqlConnection is close"))?
333 .simple_query("commit")
334 .await
335 .map_err(|e| Error::from(e.to_string()))?;
336 Ok(())
337 })
338 }
339
340 fn rollback(&mut self) -> BoxFuture<Result<(), Error>> {
341 Box::pin(async move {
342 self.inner
343 .as_mut()
344 .ok_or_else(|| Error::from("MssqlConnection is close"))?
345 .simple_query("rollback")
346 .await
347 .map_err(|e| Error::from(e.to_string()))?;
348 Ok(())
349 })
350 }
351}
352
#[cfg(test)]
mod test {
    use crate::{parse_url_connection_string, MssqlConnectOptions};
    use rbdc::db::ConnectOptions;
    use tiberius::Config;

    /// Feed `uri` to a freshly constructed options value and return the outcome.
    fn set_uri_on_fresh_options(uri: &str) -> Result<(), rbdc::Error> {
        let mut options = MssqlConnectOptions(Config::new());
        options.set_uri(uri)
    }

    #[test]
    fn test_datetime() {}

    #[test]
    fn test_connection_string_parsing() {
        let result = set_uri_on_fresh_options(
            "jdbc:sqlserver://localhost:1433;User=SA;Password={TestPass!123456};Database=master;",
        );
        assert!(result.is_ok(), "JDBC format should be supported");

        let result = set_uri_on_fresh_options("mssql://SA:TestPass!123456@localhost:1433/master");
        assert!(result.is_ok(), "mssql:// format should be supported: {:?}", result);

        let result =
            set_uri_on_fresh_options("sqlserver://SA:TestPass!123456@localhost:1433/master");
        assert!(result.is_ok(), "sqlserver:// format should be supported: {:?}", result);

        let result = set_uri_on_fresh_options(
            "Server=localhost,1433;User Id=SA;Password=TestPass!123456;Database=master;",
        );
        assert!(result.is_ok(), "ADO format should be supported");
    }

    #[test]
    fn test_url_parsing_details() {
        let cases = [
            ("mssql://testuser:testpass@example.com:1433/testdb", "example.com:1433"),
            ("mssql://testuser@localhost:1433/testdb", "localhost:1433"),
            ("mssql://testuser:testpass@localhost:1433", "localhost:1433"),
            // No explicit port: the address falls back to 1433.
            ("mssql://testuser:testpass@localhost/testdb", "localhost:1433"),
        ];
        for (url, expected_addr) in cases.iter() {
            let config = parse_url_connection_string(url).unwrap();
            assert_eq!(config.get_addr(), *expected_addr);
        }
    }

    #[test]
    fn test_url_query_parameters() {
        let with_all_options = parse_url_connection_string(
            "mssql://testuser:testpass@localhost:1433/testdb?instance=SQLEXPRESS&application_name=MyApp&encrypt=true&trust_cert=true&readonly=true",
        )
        .unwrap();
        assert_eq!(with_all_options.get_addr(), "localhost:1433");

        let with_some_options = parse_url_connection_string(
            "sqlserver://user:pass@server:1433/db?application_name=TestApp&encrypt=false",
        )
        .unwrap();
        assert_eq!(with_some_options.get_addr(), "server:1433");

        let rejected = [
            "mssql://user:pass@localhost/db?encrypt=invalid",
            "mssql://user:pass@localhost/db?trust_cert=invalid",
        ];
        for url in rejected.iter() {
            assert!(parse_url_connection_string(url).is_err());
        }
    }
}