1mod conversion;
2mod error;
3
4use super::{IsolationLevel, TransactionOptions};
5use crate::{
6 ast::{Query, Value},
7 connector::{metrics, queryable::*, ResultSet, Transaction},
8 error::{Error, ErrorKind},
9 visitor::{self, Visitor},
10};
11use async_trait::async_trait;
12use connection_string::JdbcString;
13use futures::lock::Mutex;
14use std::{
15 convert::TryFrom,
16 fmt,
17 future::Future,
18 str::FromStr,
19 sync::atomic::{AtomicBool, Ordering},
20 time::Duration,
21};
22use tiberius::*;
23use tokio::net::TcpStream;
24use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt};
25
26#[cfg(feature = "expose-drivers")]
28pub use tiberius;
29
30#[derive(Debug, Clone)]
33#[cfg_attr(feature = "docs", doc(cfg(feature = "mssql")))]
34pub struct MssqlUrl {
35 connection_string: String,
36 query_params: MssqlQueryParams,
37}
38
/// Value of the `encrypt` connection-string property.
///
/// NOTE(review): the exact TLS behaviour behind each variant is decided by the
/// driver, not by this file — confirm against the tiberius documentation.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "docs", doc(cfg(feature = "mssql")))]
pub enum EncryptMode {
    /// Rendered as `true`.
    On,
    /// Rendered as `false`.
    Off,
    /// Rendered as `DANGER_PLAINTEXT`.
    DangerPlainText,
}

impl fmt::Display for EncryptMode {
    /// Writes the textual form used in connection strings.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let rendered = match self {
            Self::On => "true",
            Self::Off => "false",
            Self::DangerPlainText => "DANGER_PLAINTEXT",
        };

        // `write_str` ignores width/padding flags, just like `write!` with a literal does.
        f.write_str(rendered)
    }
}
60
61impl FromStr for EncryptMode {
62 type Err = Error;
63
64 fn from_str(s: &str) -> crate::Result<Self> {
65 let mode = match s.parse::<bool>() {
66 Ok(true) => Self::On,
67 _ if s == "DANGER_PLAINTEXT" => Self::DangerPlainText,
68 _ => Self::Off,
69 };
70
71 Ok(mode)
72 }
73}
74
75#[derive(Debug, Clone)]
76pub(crate) struct MssqlQueryParams {
77 encrypt: EncryptMode,
78 port: Option<u16>,
79 host: Option<String>,
80 user: Option<String>,
81 password: Option<String>,
82 database: String,
83 schema: String,
84 trust_server_certificate: bool,
85 trust_server_certificate_ca: Option<String>,
86 connection_limit: Option<usize>,
87 socket_timeout: Option<Duration>,
88 connect_timeout: Option<Duration>,
89 pool_timeout: Option<Duration>,
90 transaction_isolation_level: Option<IsolationLevel>,
91 max_connection_lifetime: Option<Duration>,
92 max_idle_connection_lifetime: Option<Duration>,
93}
94
95static SQL_SERVER_DEFAULT_ISOLATION: IsolationLevel = IsolationLevel::ReadCommitted;
96
97#[async_trait]
98impl TransactionCapable for Mssql {
99 async fn start_transaction(&self, isolation: Option<IsolationLevel>) -> crate::Result<Transaction<'_>> {
100 let isolation =
105 isolation.or(self.url.query_params.transaction_isolation_level).or(Some(SQL_SERVER_DEFAULT_ISOLATION));
106
107 let opts = TransactionOptions::new(isolation, self.requires_isolation_first());
108
109 Transaction::new(self, self.begin_statement(), opts).await
110 }
111}
112
113impl MssqlUrl {
114 pub fn connection_limit(&self) -> Option<usize> {
117 self.query_params.connection_limit()
118 }
119
120 pub fn socket_timeout(&self) -> Option<Duration> {
122 self.query_params.socket_timeout()
123 }
124
125 pub fn connect_timeout(&self) -> Option<Duration> {
127 self.query_params.connect_timeout()
128 }
129
130 pub fn pool_timeout(&self) -> Option<Duration> {
132 self.query_params.pool_timeout()
133 }
134
135 fn transaction_isolation_level(&self) -> Option<IsolationLevel> {
137 self.query_params.transaction_isolation_level
138 }
139
140 pub fn dbname(&self) -> &str {
142 self.query_params.database()
143 }
144
145 pub fn schema(&self) -> &str {
147 self.query_params.schema()
148 }
149
150 pub fn host(&self) -> &str {
152 self.query_params.host()
153 }
154
155 pub fn username(&self) -> Option<&str> {
157 self.query_params.user()
158 }
159
160 pub fn password(&self) -> Option<&str> {
162 self.query_params.password()
163 }
164
165 pub fn encrypt(&self) -> EncryptMode {
167 self.query_params.encrypt()
168 }
169
170 pub fn trust_server_certificate(&self) -> bool {
174 self.query_params.trust_server_certificate()
175 }
176
177 pub fn trust_server_certificate_ca(&self) -> Option<&str> {
179 self.query_params.trust_server_certificate_ca()
180 }
181
182 pub fn port(&self) -> u16 {
184 self.query_params.port()
185 }
186
187 pub fn connection_string(&self) -> &str {
189 &self.connection_string
190 }
191
192 pub fn max_connection_lifetime(&self) -> Option<Duration> {
194 self.query_params.max_connection_lifetime()
195 }
196
197 pub fn max_idle_connection_lifetime(&self) -> Option<Duration> {
199 self.query_params.max_idle_connection_lifetime()
200 }
201}
202
203impl MssqlQueryParams {
204 fn port(&self) -> u16 {
205 self.port.unwrap_or(1433)
206 }
207
208 fn host(&self) -> &str {
209 self.host.as_deref().unwrap_or("localhost")
210 }
211
212 fn user(&self) -> Option<&str> {
213 self.user.as_deref()
214 }
215
216 fn password(&self) -> Option<&str> {
217 self.password.as_deref()
218 }
219
220 fn encrypt(&self) -> EncryptMode {
221 self.encrypt
222 }
223
224 fn trust_server_certificate(&self) -> bool {
225 self.trust_server_certificate
226 }
227
228 fn trust_server_certificate_ca(&self) -> Option<&str> {
229 self.trust_server_certificate_ca.as_deref()
230 }
231
232 fn database(&self) -> &str {
233 &self.database
234 }
235
236 fn schema(&self) -> &str {
237 &self.schema
238 }
239
240 fn socket_timeout(&self) -> Option<Duration> {
241 self.socket_timeout
242 }
243
244 fn connect_timeout(&self) -> Option<Duration> {
245 self.connect_timeout
246 }
247
248 fn connection_limit(&self) -> Option<usize> {
249 self.connection_limit
250 }
251
252 fn pool_timeout(&self) -> Option<Duration> {
253 self.pool_timeout
254 }
255
256 fn max_connection_lifetime(&self) -> Option<Duration> {
257 self.max_connection_lifetime
258 }
259
260 fn max_idle_connection_lifetime(&self) -> Option<Duration> {
261 self.max_idle_connection_lifetime
262 }
263}
264
265#[derive(Debug)]
267#[cfg_attr(feature = "docs", doc(cfg(feature = "mssql")))]
268pub struct Mssql {
269 client: Mutex<Client<Compat<TcpStream>>>,
270 url: MssqlUrl,
271 socket_timeout: Option<Duration>,
272 is_healthy: AtomicBool,
273}
274
275impl Mssql {
276 pub async fn new(url: MssqlUrl) -> crate::Result<Self> {
278 let config = Config::from_jdbc_string(&url.connection_string)?;
279 let tcp = TcpStream::connect_named(&config).await?;
280 let socket_timeout = url.socket_timeout();
281
282 let connecting = async {
283 match Client::connect(config, tcp.compat_write()).await {
284 Ok(client) => Ok(client),
285 Err(tiberius::error::Error::Routing { host, port }) => {
286 let mut config = Config::from_jdbc_string(&url.connection_string)?;
287 config.host(host);
288 config.port(port);
289
290 let tcp = TcpStream::connect_named(&config).await?;
291 Client::connect(config, tcp.compat_write()).await
292 }
293 Err(e) => Err(e),
294 }
295 };
296
297 let client = super::timeout::connect(url.connect_timeout(), connecting).await?;
298
299 let this = Self { client: Mutex::new(client), url, socket_timeout, is_healthy: AtomicBool::new(true) };
300
301 if let Some(isolation) = this.url.transaction_isolation_level() {
302 this.raw_cmd(&format!("SET TRANSACTION ISOLATION LEVEL {isolation}")).await?;
303 };
304
305 Ok(this)
306 }
307
308 #[cfg(feature = "expose-drivers")]
311 pub fn client(&self) -> &Mutex<Client<Compat<TcpStream>>> {
312 &self.client
313 }
314
315 async fn perform_io<F, T>(&self, fut: F) -> crate::Result<T>
316 where
317 F: Future<Output = std::result::Result<T, tiberius::error::Error>>,
318 {
319 match super::timeout::socket(self.socket_timeout, fut).await {
320 Err(e) if e.is_closed() => {
321 self.is_healthy.store(false, Ordering::SeqCst);
322 Err(e)
323 }
324 res => res,
325 }
326 }
327}
328
329#[async_trait]
330impl Queryable for Mssql {
331 async fn query(&self, q: Query<'_>) -> crate::Result<ResultSet> {
332 let (sql, params) = visitor::Mssql::build(q)?;
333 self.query_raw(&sql, ¶ms[..]).await
334 }
335
336 async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result<ResultSet> {
337 metrics::query("mssql.query_raw", sql, params, move || async move {
338 let mut client = self.client.lock().await;
339
340 let mut query = tiberius::Query::new(sql);
341
342 for param in params {
343 query.bind(param);
344 }
345
346 let mut results = self.perform_io(query.query(&mut client)).await?.into_results().await?;
347
348 match results.pop() {
349 Some(rows) => {
350 let mut columns_set = false;
351 let mut columns = Vec::new();
352 let mut result_rows = Vec::with_capacity(rows.len());
353
354 for row in rows.into_iter() {
355 if !columns_set {
356 columns = row.columns().iter().map(|c| c.name().to_string()).collect();
357 columns_set = true;
358 }
359
360 let mut values: Vec<Value<'_>> = Vec::with_capacity(row.len());
361
362 for val in row.into_iter() {
363 values.push(Value::try_from(val)?);
364 }
365
366 result_rows.push(values);
367 }
368
369 Ok(ResultSet::new(columns, result_rows))
370 }
371 None => Ok(ResultSet::new(Vec::new(), Vec::new())),
372 }
373 })
374 .await
375 }
376
377 async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result<ResultSet> {
378 self.query_raw(sql, params).await
379 }
380
381 async fn execute(&self, q: Query<'_>) -> crate::Result<u64> {
382 let (sql, params) = visitor::Mssql::build(q)?;
383 self.execute_raw(&sql, ¶ms[..]).await
384 }
385
386 async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result<u64> {
387 metrics::query("mssql.execute_raw", sql, params, move || async move {
388 let mut query = tiberius::Query::new(sql);
389
390 for param in params {
391 query.bind(param);
392 }
393
394 let mut client = self.client.lock().await;
395 let changes = self.perform_io(query.execute(&mut client)).await?.total();
396
397 Ok(changes)
398 })
399 .await
400 }
401
402 async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result<u64> {
403 self.execute_raw(sql, params).await
404 }
405
406 async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> {
407 metrics::query("mssql.raw_cmd", cmd, &[], move || async move {
408 let mut client = self.client.lock().await;
409 self.perform_io(client.simple_query(cmd)).await?.into_results().await?;
410 Ok(())
411 })
412 .await
413 }
414
415 async fn version(&self) -> crate::Result<Option<String>> {
416 let query = r#"SELECT @@VERSION AS version"#;
417 let rows = self.query_raw(query, &[]).await?;
418
419 let version_string = rows.get(0).and_then(|row| row.get("version").and_then(|version| version.to_string()));
420
421 Ok(version_string)
422 }
423
424 fn is_healthy(&self) -> bool {
425 self.is_healthy.load(Ordering::SeqCst)
426 }
427
428 async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> {
429 self.raw_cmd(&format!("SET TRANSACTION ISOLATION LEVEL {isolation_level}")).await?;
430
431 Ok(())
432 }
433
434 fn begin_statement(&self) -> &'static str {
435 "BEGIN TRAN"
436 }
437
438 fn requires_isolation_first(&self) -> bool {
439 true
440 }
441}
442
443impl MssqlUrl {
444 pub fn new(jdbc_connection_string: &str) -> crate::Result<Self> {
445 let query_params = Self::parse_query_params(jdbc_connection_string)?;
446 let connection_string = Self::with_jdbc_prefix(jdbc_connection_string);
447
448 Ok(Self { connection_string, query_params })
449 }
450
451 fn with_jdbc_prefix(input: &str) -> String {
452 if input.starts_with("jdbc:sqlserver") {
453 input.into()
454 } else {
455 format!("jdbc:{input}")
456 }
457 }
458
459 fn parse_query_params(input: &str) -> crate::Result<MssqlQueryParams> {
460 let mut conn = JdbcString::from_str(&Self::with_jdbc_prefix(input))?;
461
462 let host = conn.server_name().map(|server_name| match conn.instance_name() {
463 Some(instance_name) => format!(r#"{server_name}\{instance_name}"#),
464 None => server_name.to_string(),
465 });
466
467 let port = conn.port();
468 let props = conn.properties_mut();
469 let user = props.remove("user");
470 let password = props.remove("password");
471 let database = props.remove("database").unwrap_or_else(|| String::from("master"));
472 let schema = props.remove("schema").unwrap_or_else(|| String::from("dbo"));
473
474 let connection_limit = props
475 .remove("connectionlimit")
476 .or_else(|| props.remove("connection_limit"))
477 .map(|param| param.parse())
478 .transpose()?;
479
480 let transaction_isolation_level = props
481 .remove("isolationlevel")
482 .or_else(|| props.remove("isolation_level"))
483 .map(|level| {
484 IsolationLevel::from_str(&level).map_err(|_| {
485 let kind = ErrorKind::database_url_is_invalid(format!("Invalid isolation level `{level}`"));
486 Error::builder(kind).build()
487 })
488 })
489 .transpose()?;
490
491 let mut connect_timeout = props
492 .remove("logintimeout")
493 .or_else(|| props.remove("login_timeout"))
494 .or_else(|| props.remove("connecttimeout"))
495 .or_else(|| props.remove("connect_timeout"))
496 .or_else(|| props.remove("connectiontimeout"))
497 .or_else(|| props.remove("connection_timeout"))
498 .map(|param| param.parse().map(Duration::from_secs))
499 .transpose()?;
500
501 match connect_timeout {
502 None => connect_timeout = Some(Duration::from_secs(5)),
503 Some(dur) if dur.as_secs() == 0 => connect_timeout = None,
504 _ => (),
505 }
506
507 let mut pool_timeout = props
508 .remove("pooltimeout")
509 .or_else(|| props.remove("pool_timeout"))
510 .map(|param| param.parse().map(Duration::from_secs))
511 .transpose()?;
512
513 match pool_timeout {
514 None => pool_timeout = Some(Duration::from_secs(10)),
515 Some(dur) if dur.as_secs() == 0 => pool_timeout = None,
516 _ => (),
517 }
518
519 let socket_timeout = props
520 .remove("sockettimeout")
521 .or_else(|| props.remove("socket_timeout"))
522 .map(|param| param.parse().map(Duration::from_secs))
523 .transpose()?;
524
525 let encrypt =
526 props.remove("encrypt").map(|param| EncryptMode::from_str(¶m)).transpose()?.unwrap_or(EncryptMode::On);
527
528 let trust_server_certificate = props
529 .remove("trustservercertificate")
530 .or_else(|| props.remove("trust_server_certificate"))
531 .map(|param| param.parse())
532 .transpose()?
533 .unwrap_or(false);
534
535 let trust_server_certificate_ca: Option<String> =
536 props.remove("trustservercertificateca").or_else(|| props.remove("trust_server_certificate_ca"));
537
538 let mut max_connection_lifetime =
539 props.remove("max_connection_lifetime").map(|param| param.parse().map(Duration::from_secs)).transpose()?;
540
541 match max_connection_lifetime {
542 Some(dur) if dur.as_secs() == 0 => max_connection_lifetime = None,
543 _ => (),
544 }
545
546 let mut max_idle_connection_lifetime = props
547 .remove("max_idle_connection_lifetime")
548 .map(|param| param.parse().map(Duration::from_secs))
549 .transpose()?;
550
551 match max_idle_connection_lifetime {
552 None => max_idle_connection_lifetime = Some(Duration::from_secs(300)),
553 Some(dur) if dur.as_secs() == 0 => max_idle_connection_lifetime = None,
554 _ => (),
555 }
556
557 Ok(MssqlQueryParams {
558 encrypt,
559 port,
560 host,
561 user,
562 password,
563 database,
564 schema,
565 trust_server_certificate,
566 trust_server_certificate_ca,
567 connection_limit,
568 socket_timeout,
569 connect_timeout,
570 pool_timeout,
571 transaction_isolation_level,
572 max_connection_lifetime,
573 max_idle_connection_lifetime,
574 })
575 }
576}
577
#[cfg(test)]
mod tests {
    use crate::tests::test_api::mssql::CONN_STR;
    use crate::{error::*, single::Sqlint};

    /// Connecting with an unknown user must surface as
    /// `AuthenticationFailed` naming that user.
    #[tokio::test]
    async fn should_map_wrong_credentials_error() {
        let bad_credentials = CONN_STR.replace("user=SA", "user=WRONG");

        let outcome = Sqlint::new(bad_credentials.as_str()).await;
        assert!(outcome.is_err());

        let failure = outcome.unwrap_err();
        let expected_user = Name::available("WRONG");
        assert!(matches!(failure.kind(), ErrorKind::AuthenticationFailed { user } if user == &expected_user));
    }
}