sqlint/connector/mssql.rs

1mod conversion;
2mod error;
3
4use super::{IsolationLevel, TransactionOptions};
5use crate::{
6    ast::{Query, Value},
7    connector::{metrics, queryable::*, ResultSet, Transaction},
8    error::{Error, ErrorKind},
9    visitor::{self, Visitor},
10};
11use async_trait::async_trait;
12use connection_string::JdbcString;
13use futures::lock::Mutex;
14use std::{
15    convert::TryFrom,
16    fmt,
17    future::Future,
18    str::FromStr,
19    sync::atomic::{AtomicBool, Ordering},
20    time::Duration,
21};
22use tiberius::*;
23use tokio::net::TcpStream;
24use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt};
25
26/// The underlying SQL Server driver. Only available with the `expose-drivers` Cargo feature.
27#[cfg(feature = "expose-drivers")]
28pub use tiberius;
29
/// Wraps a connection url and exposes the parsing logic used by Sqlint,
/// including default values.
///
/// Built by `MssqlUrl::new`, which keeps the full JDBC string (handed to the
/// driver when connecting) and pre-parses the parameters Sqlint interprets
/// itself.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "docs", doc(cfg(feature = "mssql")))]
pub struct MssqlUrl {
    /// The JDBC connection string, normalised to start with `jdbc:`.
    connection_string: String,
    /// Parameters parsed out of the connection string.
    query_params: MssqlQueryParams,
}
38
/// TLS mode when connecting to SQL Server.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "docs", doc(cfg(feature = "mssql")))]
pub enum EncryptMode {
    /// Every byte on the wire is encrypted.
    On,
    /// Encryption is limited to the login exchange (the credentials).
    Off,
    /// No encryption whatsoever, not even for the login.
    DangerPlainText,
}

impl fmt::Display for EncryptMode {
    /// Writes the keyword used for this mode in the `encrypt` connection-string
    /// parameter: `true`, `false` or `DANGER_PLAINTEXT`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let keyword = match *self {
            EncryptMode::On => "true",
            EncryptMode::Off => "false",
            EncryptMode::DangerPlainText => "DANGER_PLAINTEXT",
        };

        f.write_str(keyword)
    }
}
60
61impl FromStr for EncryptMode {
62    type Err = Error;
63
64    fn from_str(s: &str) -> crate::Result<Self> {
65        let mode = match s.parse::<bool>() {
66            Ok(true) => Self::On,
67            _ if s == "DANGER_PLAINTEXT" => Self::DangerPlainText,
68            _ => Self::Off,
69        };
70
71        Ok(mode)
72    }
73}
74
/// Connection-string parameters Sqlint interprets itself, extracted once by
/// `MssqlUrl::parse_query_params`.
#[derive(Debug, Clone)]
pub(crate) struct MssqlQueryParams {
    /// `encrypt` parameter; `EncryptMode::On` when absent.
    encrypt: EncryptMode,
    /// Port from the connection string; the `port()` accessor falls back to 1433.
    port: Option<u16>,
    /// Server name, or `server\instance` for a named instance; the `host()`
    /// accessor falls back to `localhost`.
    host: Option<String>,
    /// `user` parameter.
    user: Option<String>,
    /// `password` parameter.
    password: Option<String>,
    /// `database` parameter; `master` when absent.
    database: String,
    /// `schema` parameter; `dbo` when absent.
    schema: String,
    /// `trustservercertificate` / `trust_server_certificate`; `false` when absent.
    trust_server_certificate: bool,
    /// `trustservercertificateca` / `trust_server_certificate_ca`: path to a
    /// custom server certificate file.
    trust_server_certificate_ca: Option<String>,
    /// `connectionlimit` / `connection_limit`.
    connection_limit: Option<usize>,
    /// `sockettimeout` / `socket_timeout`, in seconds.
    socket_timeout: Option<Duration>,
    /// Login/connect timeout in seconds; 5 s when absent, `None` when set to 0.
    connect_timeout: Option<Duration>,
    /// `pooltimeout` / `pool_timeout`; 10 s when absent, `None` when set to 0.
    pool_timeout: Option<Duration>,
    /// `isolationlevel` / `isolation_level`.
    transaction_isolation_level: Option<IsolationLevel>,
    /// `max_connection_lifetime` in seconds; `None` when absent or 0.
    max_connection_lifetime: Option<Duration>,
    /// `max_idle_connection_lifetime` in seconds; 300 s when absent, `None` when set to 0.
    max_idle_connection_lifetime: Option<Duration>,
}
94
/// Isolation level `start_transaction` falls back to when neither the caller
/// nor the connection string chooses one. READ COMMITTED is SQL Server's own
/// default.
static SQL_SERVER_DEFAULT_ISOLATION: IsolationLevel = IsolationLevel::ReadCommitted;
96
97#[async_trait]
98impl TransactionCapable for Mssql {
99    async fn start_transaction(&self, isolation: Option<IsolationLevel>) -> crate::Result<Transaction<'_>> {
100        // Isolation levels in SQL Server are set on the connection and live until they're changed.
101        // Always explicitly setting the isolation level each time a tx is started (either to the given value
102        // or by using the default/connection string value) prevents transactions started on connections from
103        // the pool to have unexpected isolation levels set.
104        let isolation =
105            isolation.or(self.url.query_params.transaction_isolation_level).or(Some(SQL_SERVER_DEFAULT_ISOLATION));
106
107        let opts = TransactionOptions::new(isolation, self.requires_isolation_first());
108
109        Transaction::new(self, self.begin_statement(), opts).await
110    }
111}
112
113impl MssqlUrl {
114    /// Maximum number of connections the pool can have (if used together with
115    /// pooled Sqlint).
116    pub fn connection_limit(&self) -> Option<usize> {
117        self.query_params.connection_limit()
118    }
119
120    /// A duration how long one query can take.
121    pub fn socket_timeout(&self) -> Option<Duration> {
122        self.query_params.socket_timeout()
123    }
124
125    /// A duration how long we can try to connect to the database.
126    pub fn connect_timeout(&self) -> Option<Duration> {
127        self.query_params.connect_timeout()
128    }
129
130    /// A pool check_out timeout.
131    pub fn pool_timeout(&self) -> Option<Duration> {
132        self.query_params.pool_timeout()
133    }
134
135    /// The isolation level of a transaction.
136    fn transaction_isolation_level(&self) -> Option<IsolationLevel> {
137        self.query_params.transaction_isolation_level
138    }
139
140    /// Name of the database.
141    pub fn dbname(&self) -> &str {
142        self.query_params.database()
143    }
144
145    /// The prefix which to use when querying database.
146    pub fn schema(&self) -> &str {
147        self.query_params.schema()
148    }
149
150    /// Database hostname.
151    pub fn host(&self) -> &str {
152        self.query_params.host()
153    }
154
155    /// The username to use when connecting to the database.
156    pub fn username(&self) -> Option<&str> {
157        self.query_params.user()
158    }
159
160    /// The password to use when connecting to the database.
161    pub fn password(&self) -> Option<&str> {
162        self.query_params.password()
163    }
164
165    /// The TLS mode to use when connecting to the database.
166    pub fn encrypt(&self) -> EncryptMode {
167        self.query_params.encrypt()
168    }
169
170    /// If true, we allow invalid certificates (self-signed, or otherwise
171    /// dangerous) when connecting. Should be true only for development and
172    /// testing.
173    pub fn trust_server_certificate(&self) -> bool {
174        self.query_params.trust_server_certificate()
175    }
176
177    /// Path to a custom server certificate file.
178    pub fn trust_server_certificate_ca(&self) -> Option<&str> {
179        self.query_params.trust_server_certificate_ca()
180    }
181
182    /// Database port.
183    pub fn port(&self) -> u16 {
184        self.query_params.port()
185    }
186
187    /// The JDBC connection string
188    pub fn connection_string(&self) -> &str {
189        &self.connection_string
190    }
191
192    /// The maximum connection lifetime
193    pub fn max_connection_lifetime(&self) -> Option<Duration> {
194        self.query_params.max_connection_lifetime()
195    }
196
197    /// The maximum idle connection lifetime
198    pub fn max_idle_connection_lifetime(&self) -> Option<Duration> {
199        self.query_params.max_idle_connection_lifetime()
200    }
201}
202
203impl MssqlQueryParams {
204    fn port(&self) -> u16 {
205        self.port.unwrap_or(1433)
206    }
207
208    fn host(&self) -> &str {
209        self.host.as_deref().unwrap_or("localhost")
210    }
211
212    fn user(&self) -> Option<&str> {
213        self.user.as_deref()
214    }
215
216    fn password(&self) -> Option<&str> {
217        self.password.as_deref()
218    }
219
220    fn encrypt(&self) -> EncryptMode {
221        self.encrypt
222    }
223
224    fn trust_server_certificate(&self) -> bool {
225        self.trust_server_certificate
226    }
227
228    fn trust_server_certificate_ca(&self) -> Option<&str> {
229        self.trust_server_certificate_ca.as_deref()
230    }
231
232    fn database(&self) -> &str {
233        &self.database
234    }
235
236    fn schema(&self) -> &str {
237        &self.schema
238    }
239
240    fn socket_timeout(&self) -> Option<Duration> {
241        self.socket_timeout
242    }
243
244    fn connect_timeout(&self) -> Option<Duration> {
245        self.connect_timeout
246    }
247
248    fn connection_limit(&self) -> Option<usize> {
249        self.connection_limit
250    }
251
252    fn pool_timeout(&self) -> Option<Duration> {
253        self.pool_timeout
254    }
255
256    fn max_connection_lifetime(&self) -> Option<Duration> {
257        self.max_connection_lifetime
258    }
259
260    fn max_idle_connection_lifetime(&self) -> Option<Duration> {
261        self.max_idle_connection_lifetime
262    }
263}
264
/// A connector interface for the SQL Server database.
///
/// Wraps a single driver connection. Every query locks `client` for its whole
/// duration, so statements issued through one `Mssql` run one at a time.
#[derive(Debug)]
#[cfg_attr(feature = "docs", doc(cfg(feature = "mssql")))]
pub struct Mssql {
    /// The driver connection, behind an async mutex (`futures::lock::Mutex`)
    /// because queries hold the guard across `.await` points.
    client: Mutex<Client<Compat<TcpStream>>>,
    /// The parsed URL this connection was opened with.
    url: MssqlUrl,
    /// Copied from `url` at connect time; applied to driver calls by `perform_io`.
    socket_timeout: Option<Duration>,
    /// Starts `true`; cleared by `perform_io` once the driver reports the
    /// connection as closed.
    is_healthy: AtomicBool,
}
274
275impl Mssql {
276    /// Creates a new connection to SQL Server.
277    pub async fn new(url: MssqlUrl) -> crate::Result<Self> {
278        let config = Config::from_jdbc_string(&url.connection_string)?;
279        let tcp = TcpStream::connect_named(&config).await?;
280        let socket_timeout = url.socket_timeout();
281
282        let connecting = async {
283            match Client::connect(config, tcp.compat_write()).await {
284                Ok(client) => Ok(client),
285                Err(tiberius::error::Error::Routing { host, port }) => {
286                    let mut config = Config::from_jdbc_string(&url.connection_string)?;
287                    config.host(host);
288                    config.port(port);
289
290                    let tcp = TcpStream::connect_named(&config).await?;
291                    Client::connect(config, tcp.compat_write()).await
292                }
293                Err(e) => Err(e),
294            }
295        };
296
297        let client = super::timeout::connect(url.connect_timeout(), connecting).await?;
298
299        let this = Self { client: Mutex::new(client), url, socket_timeout, is_healthy: AtomicBool::new(true) };
300
301        if let Some(isolation) = this.url.transaction_isolation_level() {
302            this.raw_cmd(&format!("SET TRANSACTION ISOLATION LEVEL {isolation}")).await?;
303        };
304
305        Ok(this)
306    }
307
308    /// The underlying Tiberius client. Only available with the `expose-drivers` Cargo feature.
309    /// This is a lower level API when you need to get into database specific features.
310    #[cfg(feature = "expose-drivers")]
311    pub fn client(&self) -> &Mutex<Client<Compat<TcpStream>>> {
312        &self.client
313    }
314
315    async fn perform_io<F, T>(&self, fut: F) -> crate::Result<T>
316    where
317        F: Future<Output = std::result::Result<T, tiberius::error::Error>>,
318    {
319        match super::timeout::socket(self.socket_timeout, fut).await {
320            Err(e) if e.is_closed() => {
321                self.is_healthy.store(false, Ordering::SeqCst);
322                Err(e)
323            }
324            res => res,
325        }
326    }
327}
328
329#[async_trait]
330impl Queryable for Mssql {
331    async fn query(&self, q: Query<'_>) -> crate::Result<ResultSet> {
332        let (sql, params) = visitor::Mssql::build(q)?;
333        self.query_raw(&sql, &params[..]).await
334    }
335
336    async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result<ResultSet> {
337        metrics::query("mssql.query_raw", sql, params, move || async move {
338            let mut client = self.client.lock().await;
339
340            let mut query = tiberius::Query::new(sql);
341
342            for param in params {
343                query.bind(param);
344            }
345
346            let mut results = self.perform_io(query.query(&mut client)).await?.into_results().await?;
347
348            match results.pop() {
349                Some(rows) => {
350                    let mut columns_set = false;
351                    let mut columns = Vec::new();
352                    let mut result_rows = Vec::with_capacity(rows.len());
353
354                    for row in rows.into_iter() {
355                        if !columns_set {
356                            columns = row.columns().iter().map(|c| c.name().to_string()).collect();
357                            columns_set = true;
358                        }
359
360                        let mut values: Vec<Value<'_>> = Vec::with_capacity(row.len());
361
362                        for val in row.into_iter() {
363                            values.push(Value::try_from(val)?);
364                        }
365
366                        result_rows.push(values);
367                    }
368
369                    Ok(ResultSet::new(columns, result_rows))
370                }
371                None => Ok(ResultSet::new(Vec::new(), Vec::new())),
372            }
373        })
374        .await
375    }
376
377    async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result<ResultSet> {
378        self.query_raw(sql, params).await
379    }
380
381    async fn execute(&self, q: Query<'_>) -> crate::Result<u64> {
382        let (sql, params) = visitor::Mssql::build(q)?;
383        self.execute_raw(&sql, &params[..]).await
384    }
385
386    async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result<u64> {
387        metrics::query("mssql.execute_raw", sql, params, move || async move {
388            let mut query = tiberius::Query::new(sql);
389
390            for param in params {
391                query.bind(param);
392            }
393
394            let mut client = self.client.lock().await;
395            let changes = self.perform_io(query.execute(&mut client)).await?.total();
396
397            Ok(changes)
398        })
399        .await
400    }
401
402    async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result<u64> {
403        self.execute_raw(sql, params).await
404    }
405
406    async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> {
407        metrics::query("mssql.raw_cmd", cmd, &[], move || async move {
408            let mut client = self.client.lock().await;
409            self.perform_io(client.simple_query(cmd)).await?.into_results().await?;
410            Ok(())
411        })
412        .await
413    }
414
415    async fn version(&self) -> crate::Result<Option<String>> {
416        let query = r#"SELECT @@VERSION AS version"#;
417        let rows = self.query_raw(query, &[]).await?;
418
419        let version_string = rows.get(0).and_then(|row| row.get("version").and_then(|version| version.to_string()));
420
421        Ok(version_string)
422    }
423
424    fn is_healthy(&self) -> bool {
425        self.is_healthy.load(Ordering::SeqCst)
426    }
427
428    async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> {
429        self.raw_cmd(&format!("SET TRANSACTION ISOLATION LEVEL {isolation_level}")).await?;
430
431        Ok(())
432    }
433
434    fn begin_statement(&self) -> &'static str {
435        "BEGIN TRAN"
436    }
437
438    fn requires_isolation_first(&self) -> bool {
439        true
440    }
441}
442
443impl MssqlUrl {
444    pub fn new(jdbc_connection_string: &str) -> crate::Result<Self> {
445        let query_params = Self::parse_query_params(jdbc_connection_string)?;
446        let connection_string = Self::with_jdbc_prefix(jdbc_connection_string);
447
448        Ok(Self { connection_string, query_params })
449    }
450
451    fn with_jdbc_prefix(input: &str) -> String {
452        if input.starts_with("jdbc:sqlserver") {
453            input.into()
454        } else {
455            format!("jdbc:{input}")
456        }
457    }
458
459    fn parse_query_params(input: &str) -> crate::Result<MssqlQueryParams> {
460        let mut conn = JdbcString::from_str(&Self::with_jdbc_prefix(input))?;
461
462        let host = conn.server_name().map(|server_name| match conn.instance_name() {
463            Some(instance_name) => format!(r#"{server_name}\{instance_name}"#),
464            None => server_name.to_string(),
465        });
466
467        let port = conn.port();
468        let props = conn.properties_mut();
469        let user = props.remove("user");
470        let password = props.remove("password");
471        let database = props.remove("database").unwrap_or_else(|| String::from("master"));
472        let schema = props.remove("schema").unwrap_or_else(|| String::from("dbo"));
473
474        let connection_limit = props
475            .remove("connectionlimit")
476            .or_else(|| props.remove("connection_limit"))
477            .map(|param| param.parse())
478            .transpose()?;
479
480        let transaction_isolation_level = props
481            .remove("isolationlevel")
482            .or_else(|| props.remove("isolation_level"))
483            .map(|level| {
484                IsolationLevel::from_str(&level).map_err(|_| {
485                    let kind = ErrorKind::database_url_is_invalid(format!("Invalid isolation level `{level}`"));
486                    Error::builder(kind).build()
487                })
488            })
489            .transpose()?;
490
491        let mut connect_timeout = props
492            .remove("logintimeout")
493            .or_else(|| props.remove("login_timeout"))
494            .or_else(|| props.remove("connecttimeout"))
495            .or_else(|| props.remove("connect_timeout"))
496            .or_else(|| props.remove("connectiontimeout"))
497            .or_else(|| props.remove("connection_timeout"))
498            .map(|param| param.parse().map(Duration::from_secs))
499            .transpose()?;
500
501        match connect_timeout {
502            None => connect_timeout = Some(Duration::from_secs(5)),
503            Some(dur) if dur.as_secs() == 0 => connect_timeout = None,
504            _ => (),
505        }
506
507        let mut pool_timeout = props
508            .remove("pooltimeout")
509            .or_else(|| props.remove("pool_timeout"))
510            .map(|param| param.parse().map(Duration::from_secs))
511            .transpose()?;
512
513        match pool_timeout {
514            None => pool_timeout = Some(Duration::from_secs(10)),
515            Some(dur) if dur.as_secs() == 0 => pool_timeout = None,
516            _ => (),
517        }
518
519        let socket_timeout = props
520            .remove("sockettimeout")
521            .or_else(|| props.remove("socket_timeout"))
522            .map(|param| param.parse().map(Duration::from_secs))
523            .transpose()?;
524
525        let encrypt =
526            props.remove("encrypt").map(|param| EncryptMode::from_str(&param)).transpose()?.unwrap_or(EncryptMode::On);
527
528        let trust_server_certificate = props
529            .remove("trustservercertificate")
530            .or_else(|| props.remove("trust_server_certificate"))
531            .map(|param| param.parse())
532            .transpose()?
533            .unwrap_or(false);
534
535        let trust_server_certificate_ca: Option<String> =
536            props.remove("trustservercertificateca").or_else(|| props.remove("trust_server_certificate_ca"));
537
538        let mut max_connection_lifetime =
539            props.remove("max_connection_lifetime").map(|param| param.parse().map(Duration::from_secs)).transpose()?;
540
541        match max_connection_lifetime {
542            Some(dur) if dur.as_secs() == 0 => max_connection_lifetime = None,
543            _ => (),
544        }
545
546        let mut max_idle_connection_lifetime = props
547            .remove("max_idle_connection_lifetime")
548            .map(|param| param.parse().map(Duration::from_secs))
549            .transpose()?;
550
551        match max_idle_connection_lifetime {
552            None => max_idle_connection_lifetime = Some(Duration::from_secs(300)),
553            Some(dur) if dur.as_secs() == 0 => max_idle_connection_lifetime = None,
554            _ => (),
555        }
556
557        Ok(MssqlQueryParams {
558            encrypt,
559            port,
560            host,
561            user,
562            password,
563            database,
564            schema,
565            trust_server_certificate,
566            trust_server_certificate_ca,
567            connection_limit,
568            socket_timeout,
569            connect_timeout,
570            pool_timeout,
571            transaction_isolation_level,
572            max_connection_lifetime,
573            max_idle_connection_lifetime,
574        })
575    }
576}
577
#[cfg(test)]
mod tests {
    use crate::tests::test_api::mssql::CONN_STR;
    use crate::{error::*, single::Sqlint};

    /// Logging in as a user the server does not know must surface as
    /// `ErrorKind::AuthenticationFailed` naming that user.
    #[tokio::test]
    async fn should_map_wrong_credentials_error() {
        let bad_url = CONN_STR.replace("user=SA", "user=WRONG");

        let err = match Sqlint::new(bad_url.as_str()).await {
            Ok(_) => panic!("connecting as an unknown user must fail"),
            Err(err) => err,
        };

        let expected_user = Name::available("WRONG");
        assert!(matches!(err.kind(), ErrorKind::AuthenticationFailed { user } if user == &expected_user));
    }
}