1#![allow(mismatched_lifetime_syntaxes)]
2
3pub extern crate tiberius;
4
5pub mod decode;
6pub mod driver;
7pub mod encode;
8
9pub use crate::driver::MssqlDriver;
10pub use crate::driver::MssqlDriver as Driver;
11
12use crate::decode::Decode;
13use crate::encode::Encode;
14use futures_core::future::BoxFuture;
15use futures_core::Stream;
16use percent_encoding::percent_decode_str;
17use rbdc::db::{ConnectOptions, Connection, ExecResult, MetaData, Placeholder, Row};
18use rbdc::Error;
19use rbs::Value;
20use std::sync::Arc;
21use tiberius::{AuthMethod, Client, Column, ColumnData, Config, EncryptionLevel, Query};
22use tokio::net::TcpStream;
23use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt};
24use url::Url;
25
/// A single SQL Server connection used by rbdc.
///
/// `inner` holds the tiberius client running over a tokio `TcpStream`. It becomes
/// `None` once `close` has taken it; after that, every query/transaction method
/// fails with the error "MssqlConnection is close".
pub struct MssqlConnection {
    inner: Option<Client<Compat<TcpStream>>>,
}
29
30impl MssqlConnection {
31 pub async fn establish(cfg: &Config) -> Result<Self, Error> {
33 let tcp = TcpStream::connect(cfg.get_addr())
35 .await
36 .map_err(|e| Error::from(e.to_string()))?;
37 tcp.set_nodelay(true)?;
38 let c = Client::connect(cfg.clone(), tcp.compat_write())
39 .await
40 .map_err(|e| Error::from(e.to_string()))?;
41 Ok(Self { inner: Some(c) })
42 }
43}
44
/// Connection options for SQL Server: a thin wrapper around a tiberius [`Config`].
///
/// Can be filled from a connection string via `ConnectOptions::set_uri`.
#[derive(Debug)]
pub struct MssqlConnectOptions(pub Config);
47
48impl ConnectOptions for MssqlConnectOptions {
49 fn connect(&self) -> BoxFuture<'_, Result<Box<dyn Connection>, Error>> {
50 Box::pin(async move {
51 let v = MssqlConnection::establish(&self.0)
52 .await
53 .map_err(|e| Error::from(e.to_string()))?;
54 Ok(Box::new(v) as Box<dyn Connection>)
55 })
56 }
57
58 fn set_uri(&mut self, url: &str) -> Result<(), Error> {
59 if url.contains("jdbc") {
60 let mut config =
61 Config::from_jdbc_string(url).map_err(|e| Error::from(e.to_string()))?;
62 config.trust_cert();
63 *self = MssqlConnectOptions(config);
64 } else if url.starts_with("mssql://") || url.starts_with("sqlserver://") {
65 let mut config = parse_url_connection_string(url)?;
66 config.trust_cert();
67 *self = MssqlConnectOptions(config);
68 } else {
69 let mut config =
70 Config::from_ado_string(url).map_err(|e| Error::from(e.to_string()))?;
71 config.trust_cert();
72 *self = MssqlConnectOptions(config);
73 }
74 Ok(())
75 }
76}
77
78fn parse_url_connection_string(url: &str) -> Result<Config, Error> {
89 let parsed_url = Url::parse(url).map_err(|e| Error::from(e.to_string()))?;
90
91 let mut config = Config::new();
92
93 if let Some(host) = parsed_url.host_str() {
95 config.host(host.to_string());
96 }
97
98 if let Some(port) = parsed_url.port() {
100 config.port(port);
101 }
102
103 let username = parsed_url.username();
105 if !username.is_empty() {
106 let decoded_username = percent_decode_str(username)
107 .decode_utf8()
108 .map_err(|e| Error::from(e.to_string()))?;
109
110 if let Some(password) = parsed_url.password() {
111 let decoded_password = percent_decode_str(password)
112 .decode_utf8()
113 .map_err(|e| Error::from(e.to_string()))?;
114 config.authentication(AuthMethod::sql_server(&decoded_username, &decoded_password));
115 } else {
116 config.authentication(AuthMethod::sql_server(&decoded_username, ""));
117 }
118 }
119
120 let path = parsed_url.path().trim_start_matches('/');
122 if !path.is_empty() {
123 config.database(path);
124 }
125
126 for (key, value) in parsed_url.query_pairs() {
128 match key.to_lowercase().as_str() {
129 "instance" | "instance_name" => {
130 config.instance_name(&*value);
131 }
132 "application_name" | "applicationname" => {
133 config.application_name(&*value);
134 }
135 "encrypt" | "encryption" => match value.to_lowercase().as_str() {
136 "true" | "yes" => {
137 #[cfg(any(feature = "tls-rustls", feature = "tls-native-tls"))]
138 config.encryption(EncryptionLevel::Required);
139 }
140 "false" | "no" => {
141 #[cfg(any(feature = "tls-rustls", feature = "tls-native-tls"))]
142 config.encryption(EncryptionLevel::Off);
143 }
144 "danger_plaintext" => {
145 config.encryption(EncryptionLevel::NotSupported);
146 }
147 _ => {
148 return Err(Error::from(format!("Invalid encryption value: {}", value)));
149 }
150 },
151 "trust_cert" | "trustservercertificate" => {
152 match value.to_lowercase().as_str() {
153 "true" | "yes" => {
154 config.trust_cert();
155 }
156 "false" | "no" => {
157 }
159 _ => {
160 return Err(Error::from(format!("Invalid trust_cert value: {}", value)));
161 }
162 }
163 }
164 "readonly" | "applicationintent" => match value.to_lowercase().as_str() {
165 "true" | "yes" | "readonly" => {
166 config.readonly(true);
167 }
168 "false" | "no" | "readwrite" => {
169 config.readonly(false);
170 }
171 _ => {
172 return Err(Error::from(format!("Invalid readonly value: {}", value)));
173 }
174 },
175 _ => {
176 }
178 }
179 }
180
181 Ok(config)
182}
183
/// One row produced by `Connection::get_rows`.
#[derive(Debug)]
pub struct MssqlRow {
    /// Column metadata (name and tiberius column type) for this row.
    pub columns: Arc<Vec<Column>>,
    /// Raw column values, one per column in column order; converted to rbs
    /// `Value`s on demand by `Row::get`.
    pub datas: Vec<ColumnData<'static>>,
}
188}
189
/// Column metadata of an [`MssqlRow`], exposed through rbdc's [`MetaData`] trait.
#[derive(Debug)]
pub struct MssqlMetaData(pub Arc<Vec<Column>>);
192
193impl MetaData for MssqlMetaData {
194 fn column_len(&self) -> usize {
195 self.0.len()
196 }
197
198 fn column_name(&self, i: usize) -> String {
199 self.0[i].name().to_string()
200 }
201
202 fn column_type(&self, i: usize) -> String {
203 format!("{:?}", self.0[i].column_type())
204 }
205}
206
207impl Row for MssqlRow {
208 fn meta_data(&self) -> Box<dyn MetaData> {
209 Box::new(MssqlMetaData(self.columns.clone()))
210 }
211
212 fn get(&mut self, i: usize) -> Result<Value, Error> {
213 Value::decode(&self.datas[i])
214 }
215}
216
217impl Connection for MssqlConnection {
218 fn get_rows(
219 &mut self,
220 sql: &str,
221 params: Vec<Value>,
222 ) -> BoxFuture<'_, Result<Vec<Box<dyn Row>>, Error>> {
223 let sql = MssqlDriver {}.exchange(sql);
224 Box::pin(async move {
225 let mut q = Query::new(sql);
226 for x in params {
227 x.encode(&mut q)?;
228 }
229 let v = q
230 .query(
231 self.inner
232 .as_mut()
233 .ok_or_else(|| Error::from("MssqlConnection is close"))?,
234 )
235 .await
236 .map_err(|e| Error::from(e.to_string()))?;
237 let mut results = Vec::with_capacity(v.size_hint().0);
238 let s = v
239 .into_results()
240 .await
241 .map_err(|e| Error::from(e.to_string()))?;
242 for item in s {
243 for r in item {
244 let mut columns = Vec::with_capacity(r.columns().len());
245 let mut row = MssqlRow {
246 columns: Arc::new(vec![]),
247 datas: Vec::with_capacity(r.columns().len()),
248 };
249 for x in r.columns() {
250 columns.push(x.clone());
251 }
252 row.columns = Arc::new(columns);
253 for x in r {
254 row.datas.push(x);
255 }
256 results.push(Box::new(row) as Box<dyn Row>);
257 }
258 }
259 Ok(results)
260 })
261 }
262
263 fn exec(&mut self, sql: &str, params: Vec<Value>) -> BoxFuture<'_, Result<ExecResult, Error>> {
264 let sql = MssqlDriver {}.exchange(sql);
265 Box::pin(async move {
266 let mut q = Query::new(sql);
267 for x in params {
268 x.encode(&mut q)?;
269 }
270 let v = q
271 .execute(
272 self.inner
273 .as_mut()
274 .ok_or_else(|| Error::from("MssqlConnection is close"))?,
275 )
276 .await
277 .map_err(|e| Error::from(e.to_string()))?;
278 Ok(ExecResult {
279 rows_affected: {
280 let mut rows_affected = 0;
281 for x in v.rows_affected() {
282 rows_affected += x.clone();
283 }
284 rows_affected
285 },
286 last_insert_id: Value::Null,
287 })
288 })
289 }
290
291 fn close(&mut self) -> BoxFuture<'_, Result<(), Error>> {
292 Box::pin(async move {
293 if let Some(v) = self.inner.take() {
295 v.close().await.map_err(|e| Error::from(e.to_string()))?;
296 }
297 Ok(())
298 })
299 }
300
301 fn ping(&mut self) -> BoxFuture<'_, Result<(), rbdc::Error>> {
302 Box::pin(async move {
304 self.inner
305 .as_mut()
306 .ok_or_else(|| Error::from("MssqlConnection is close"))?
307 .query("select 1", &[])
308 .await
309 .map_err(|e| Error::from(e.to_string()))?;
310 Ok(())
311 })
312 }
313
314 fn begin(&mut self) -> BoxFuture<'_, Result<(), Error>> {
315 Box::pin(async move {
316 self.inner
317 .as_mut()
318 .ok_or_else(|| Error::from("MssqlConnection is close"))?
319 .simple_query("begin tran")
320 .await
321 .map_err(|e| Error::from(e.to_string()))?;
322 Ok(())
323 })
324 }
325
326 fn commit(&mut self) -> BoxFuture<'_, Result<(), Error>> {
327 Box::pin(async move {
328 self.inner
329 .as_mut()
330 .ok_or_else(|| Error::from("MssqlConnection is close"))?
331 .simple_query("commit")
332 .await
333 .map_err(|e| Error::from(e.to_string()))?;
334 Ok(())
335 })
336 }
337
338 fn rollback(&mut self) -> BoxFuture<'_, Result<(), Error>> {
339 Box::pin(async move {
340 self.inner
341 .as_mut()
342 .ok_or_else(|| Error::from("MssqlConnection is close"))?
343 .simple_query("rollback")
344 .await
345 .map_err(|e| Error::from(e.to_string()))?;
346 Ok(())
347 })
348 }
349}
350
#[cfg(test)]
mod test {
    use crate::{parse_url_connection_string, MssqlConnectOptions};
    use rbdc::db::ConnectOptions;
    use tiberius::Config;

    /// `set_uri` accepts each supported connection-string family.
    #[test]
    fn test_connection_string_parsing() {
        let jdbc_uri =
            "jdbc:sqlserver://localhost:1433;User=SA;Password={TestPass!123456};Database=master;";
        let mut options = MssqlConnectOptions(Config::new());
        let result = options.set_uri(jdbc_uri);
        assert!(result.is_ok(), "JDBC format should be supported");

        let mssql_uri = "mssql://SA:TestPass!123456@localhost:1433/master";
        let mut options = MssqlConnectOptions(Config::new());
        let result = options.set_uri(mssql_uri);
        assert!(
            result.is_ok(),
            "mssql:// format should be supported: {:?}",
            result
        );

        let sqlserver_uri = "sqlserver://SA:TestPass!123456@localhost:1433/master";
        let mut options = MssqlConnectOptions(Config::new());
        let result = options.set_uri(sqlserver_uri);
        assert!(
            result.is_ok(),
            "sqlserver:// format should be supported: {:?}",
            result
        );

        let ado_uri = "Server=localhost,1433;User Id=SA;Password=TestPass!123456;Database=master;";
        let mut options = MssqlConnectOptions(Config::new());
        let result = options.set_uri(ado_uri);
        assert!(result.is_ok(), "ADO format should be supported");
    }

    /// Host and port are taken from the URL, with 1433 as the default port.
    #[test]
    fn test_url_parsing_details() {
        let config =
            parse_url_connection_string("mssql://testuser:testpass@example.com:1433/testdb")
                .unwrap();
        assert_eq!(config.get_addr(), "example.com:1433");

        let config = parse_url_connection_string("mssql://testuser@localhost:1433/testdb").unwrap();
        assert_eq!(config.get_addr(), "localhost:1433");

        let config =
            parse_url_connection_string("mssql://testuser:testpass@localhost:1433").unwrap();
        assert_eq!(config.get_addr(), "localhost:1433");

        let config =
            parse_url_connection_string("mssql://testuser:testpass@localhost/testdb").unwrap();
        assert_eq!(config.get_addr(), "localhost:1433");
    }

    /// Recognised query parameters are accepted; invalid values are rejected.
    #[test]
    fn test_url_query_parameters() {
        let config = parse_url_connection_string(
            "mssql://testuser:testpass@localhost:1433/testdb?instance=SQLEXPRESS&application_name=MyApp&encrypt=true&trust_cert=true&readonly=true"
        ).unwrap();
        assert_eq!(config.get_addr(), "localhost:1433");

        let config = parse_url_connection_string(
            "sqlserver://user:pass@server:1433/db?application_name=TestApp&encrypt=false",
        )
        .unwrap();
        assert_eq!(config.get_addr(), "server:1433");

        let result = parse_url_connection_string("mssql://user:pass@localhost/db?encrypt=invalid");
        assert!(result.is_err());

        let result =
            parse_url_connection_string("mssql://user:pass@localhost/db?trust_cert=invalid");
        assert!(result.is_err());
    }

    /// Key aliases, case-insensitive values, encoded credentials and malformed
    /// input.
    #[test]
    fn test_url_parameter_aliases_and_errors() {
        for uri in [
            "mssql://u:p@localhost/db?applicationintent=ReadOnly",
            "mssql://u:p@localhost/db?ApplicationIntent=ReadWrite",
            "mssql://u:p@localhost/db?readonly=no",
            "mssql://u:p@localhost/db?unknown_key=whatever",
            "mssql://u:p@localhost/db?encrypt=danger_plaintext",
            "mssql://user%40corp:p%40ss%3Aword@localhost:1433/db",
        ] {
            assert!(
                parse_url_connection_string(uri).is_ok(),
                "{} should parse",
                uri
            );
        }

        // Unsupported readonly value.
        assert!(parse_url_connection_string("mssql://u:p@localhost/db?readonly=maybe").is_err());
        // Username that is not valid UTF-8 after percent-decoding.
        assert!(parse_url_connection_string("mssql://%FF:p@localhost/db").is_err());
        // Unparseable URL (malformed bracketed host).
        assert!(parse_url_connection_string("mssql://[not-ipv6/db").is_err());
    }
}
443}