Skip to main content

mssql_tds/
core.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4use crate::error::Error;
5use crate::error::Error::OperationCancelledError;
6use std::future::Future;
7use std::path::PathBuf;
8use tokio_util::sync::CancellationToken;
9
10/// Alias for `Result<T, crate::error::Error>` used throughout the crate.
11pub type TdsResult<T> = Result<T, Error>;
12
/// Application-Layer Protocol Negotiation (ALPN) identifier for TDS 8.0 connections.
pub const TDS_8_ALPN_PROTOCOL: &str = "tds/8.0";
15
16/// Cooperative cancellation handle backed by a [`CancellationToken`].
17///
18/// Pass to [`TdsConnectionProvider::create_client()`](crate::connection_provider::tds_connection_provider::TdsConnectionProvider::create_client)
19/// to cancel a pending connect, or hold for later query cancellation.
20#[derive(Debug)]
21pub struct CancelHandle {
22    pub(crate) cancel_token: CancellationToken,
23}
24
25impl CancelHandle {
26    /// Create a new, uncancelled handle.
27    pub fn new() -> Self {
28        CancelHandle {
29            cancel_token: CancellationToken::new(),
30        }
31    }
32
33    /// Trigger cancellation, notifying all child handles.
34    pub fn cancel(self) {
35        self.cancel_token.cancel();
36    }
37
38    /// Derive a child handle that is cancelled when this handle is.
39    pub fn child_handle(&self) -> Self {
40        Self::from(self.cancel_token.child_token())
41    }
42
43    pub(crate) async fn run_until_cancelled<F, ResultType>(
44        cancel_handle: Option<&CancelHandle>,
45        f: F,
46    ) -> F::Output
47    where
48        F: Future<Output = TdsResult<ResultType>> + Send,
49    {
50        match cancel_handle {
51            Some(handle) => match handle.cancel_token.run_until_cancelled(f).await {
52                Some(result) => result,
53                None => Err(OperationCancelledError("Request was cancelled".to_string())),
54            },
55            None => f.await,
56        }
57    }
58}
59
60impl From<CancellationToken> for CancelHandle {
61    fn from(value: CancellationToken) -> Self {
62        CancelHandle {
63            cancel_token: value,
64        }
65    }
66}
67
68impl Default for CancelHandle {
69    fn default() -> Self {
70        Self::new()
71    }
72}
73
/// SQL Server major-version discriminant derived from the server's reported version.
///
/// Each variant's discriminant equals the server's major version number, so
/// `variant as u8` recovers that number. Like [`Version`], this is a small value type
/// and is therefore `Copy`; it is also `Eq + Hash` so it can key sets and maps.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub enum SQLServerVersion {
    /// Unsupported or unknown server version.
    SqlServerNotsupported = 0,
    /// SQL Server 2000.
    SqlServer2000 = 8,
    /// SQL Server 2005.
    SqlServer2005 = 9,
    /// SQL Server 2008 / 2008 R2.
    SqlServer2008 = 10,
    /// SQL Server 2012.
    SqlServer2012 = 11,
    /// SQL Server 2014.
    SqlServer2014 = 12,
    /// SQL Server 2016.
    SqlServer2016 = 13,
    /// SQL Server 2017.
    SqlServer2017 = 14,
    /// SQL Server 2019.
    SqlServer2019 = 15,
    /// SQL Server 2022.
    SqlServer2022 = 16,
    /// Major version 17: the release after SQL Server 2022 (SQL Server 2025).
    ///
    /// The variant's spelling is kept unchanged for API compatibility.
    SqlServer2022lus = 17,
}
100
101impl From<u8> for SQLServerVersion {
102    fn from(v: u8) -> Self {
103        match v {
104            0 => SQLServerVersion::SqlServerNotsupported,
105            8 => SQLServerVersion::SqlServer2000,
106            9 => SQLServerVersion::SqlServer2005,
107            10 => SQLServerVersion::SqlServer2008,
108            11 => SQLServerVersion::SqlServer2012,
109            12 => SQLServerVersion::SqlServer2014,
110            13 => SQLServerVersion::SqlServer2016,
111            14 => SQLServerVersion::SqlServer2017,
112            15 => SQLServerVersion::SqlServer2019,
113            16 => SQLServerVersion::SqlServer2022,
114            17 => SQLServerVersion::SqlServer2022lus,
115            _ => SQLServerVersion::SqlServerNotsupported,
116        }
117    }
118}
119
/// Four-part server version reported during the TDS pre-login handshake.
///
/// Fields are declared most-significant first, so the derived `Ord` compares versions
/// component by component (major, then minor, then build, then revision). For example,
/// `Version::new(16, 0, 0, 0) > Version::new(15, 0, 4000, 0)`.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
pub struct Version {
    /// Major version number.
    pub major: u8,
    /// Minor version number.
    pub minor: u8,
    /// Build number.
    pub build: u16,
    /// Revision number.
    pub revision: u16,
}

impl Version {
    /// Creates a new `Version` from its four components.
    ///
    /// This is a `const fn`, so it can also be used to define version constants.
    pub const fn new(major: u8, minor: u8, build: u16, revision: u16) -> Self {
        Version {
            major,
            minor,
            build,
            revision,
        }
    }
}
144
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn sql_server_version_from_known_values() {
        assert_eq!(
            SQLServerVersion::from(0),
            SQLServerVersion::SqlServerNotsupported
        );
        assert_eq!(SQLServerVersion::from(8), SQLServerVersion::SqlServer2000);
        assert_eq!(SQLServerVersion::from(9), SQLServerVersion::SqlServer2005);
        assert_eq!(SQLServerVersion::from(10), SQLServerVersion::SqlServer2008);
        assert_eq!(SQLServerVersion::from(11), SQLServerVersion::SqlServer2012);
        assert_eq!(SQLServerVersion::from(12), SQLServerVersion::SqlServer2014);
        assert_eq!(SQLServerVersion::from(13), SQLServerVersion::SqlServer2016);
        assert_eq!(SQLServerVersion::from(14), SQLServerVersion::SqlServer2017);
        assert_eq!(SQLServerVersion::from(15), SQLServerVersion::SqlServer2019);
        assert_eq!(SQLServerVersion::from(16), SQLServerVersion::SqlServer2022);
        assert_eq!(
            SQLServerVersion::from(17),
            SQLServerVersion::SqlServer2022lus
        );
    }

    #[test]
    fn sql_server_version_from_unknown_defaults_to_not_supported() {
        assert_eq!(
            SQLServerVersion::from(1),
            SQLServerVersion::SqlServerNotsupported
        );
        assert_eq!(
            SQLServerVersion::from(7),
            SQLServerVersion::SqlServerNotsupported
        );
        assert_eq!(
            SQLServerVersion::from(18),
            SQLServerVersion::SqlServerNotsupported
        );
        assert_eq!(
            SQLServerVersion::from(255),
            SQLServerVersion::SqlServerNotsupported
        );
    }

    #[test]
    fn version_new_sets_all_fields() {
        let v = Version::new(16, 0, 1000, 6);
        assert_eq!(
            v,
            Version {
                major: 16,
                minor: 0,
                build: 1000,
                revision: 6,
            }
        );
    }

    #[test]
    fn cancel_handle_default() {
        let handle = CancelHandle::default();
        assert!(!handle.cancel_token.is_cancelled());
    }

    #[test]
    fn child_handle_is_cancelled_with_parent() {
        let parent = CancelHandle::new();
        let child = parent.child_handle();
        assert!(!child.cancel_token.is_cancelled());
        parent.cancel();
        assert!(child.cancel_token.is_cancelled());
    }

    #[tokio::test]
    async fn run_until_cancelled_none_handle() {
        let result: TdsResult<i32> =
            CancelHandle::run_until_cancelled(None, async { Ok(42) }).await;
        assert_eq!(result.unwrap(), 42);
    }

    #[tokio::test]
    async fn run_until_cancelled_with_handle_completes() {
        let handle = CancelHandle::new();
        let result: TdsResult<i32> =
            CancelHandle::run_until_cancelled(Some(&handle), async { Ok(99) }).await;
        assert_eq!(result.unwrap(), 99);
    }

    #[tokio::test]
    async fn run_until_cancelled_with_cancelled_handle_returns_error() {
        let parent = CancelHandle::new();
        let child = parent.child_handle();
        parent.cancel();
        // The future never completes, so only cancellation can end the wait.
        let result: TdsResult<i32> = CancelHandle::run_until_cancelled(
            Some(&child),
            std::future::pending::<TdsResult<i32>>(),
        )
        .await;
        assert!(matches!(
            result,
            Err(crate::error::Error::OperationCancelledError(_))
        ));
    }

    #[test]
    fn encryption_options_default_is_strict_and_validating() {
        let opts = EncryptionOptions::default();
        assert_eq!(opts.mode, EncryptionSetting::Strict);
        assert!(!opts.trust_server_certificate);
        assert_eq!(opts.host_name_in_cert, None);
        assert_eq!(opts.server_certificate, None);
    }
}
211
212/// TLS and encryption settings for a TDS connection.
213#[derive(Clone, PartialEq, Debug)]
214pub struct EncryptionOptions {
215    /// Encryption mode negotiated with the server.
216    pub mode: EncryptionSetting,
217    /// Skip server certificate chain validation.
218    pub trust_server_certificate: bool,
219    /// Expected CN or SAN in the server certificate.
220    pub host_name_in_cert: Option<String>,
221    /// Path to a DER or PEM encoded X.509 certificate file for certificate pinning.
222    /// When specified, the driver performs an exact binary match between the provided
223    /// certificate and the server's certificate, bypassing standard CA chain validation.
224    pub server_certificate: Option<PathBuf>,
225}
226
227impl EncryptionOptions {
228    /// Creates encryption options defaulting to `Strict` mode.
229    pub fn new() -> Self {
230        EncryptionOptions {
231            mode: EncryptionSetting::Strict,
232            trust_server_certificate: false,
233            host_name_in_cert: None,
234            server_certificate: None,
235        }
236    }
237}
238
239impl Default for EncryptionOptions {
240    fn default() -> Self {
241        Self::new()
242    }
243}
244
/// Encryption level the client requests during the TDS pre-login.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EncryptionSetting {
    /// Prefer an unencrypted connection when the server permits one.
    PreferOff,
    /// Encrypt the connection once pre-login has completed.
    On,
    /// Require encryption after pre-login; semantically identical to `On`.
    Required,
    /// Encrypt the whole stream, pre-login included (TDS 8.0).
    Strict,
}
257
/// Crate-internal encryption outcome for a connection.
///
/// NOTE(review): variant meanings are inferred from their names only. Presumably
/// `LoginOnly` encrypts just the login exchange and `Mandatory` encrypts the whole
/// session after pre-login; confirm against the pre-login negotiation code.
#[derive(Debug, Clone, Copy, PartialEq)]
pub(crate) enum NegotiatedEncryptionSetting {
    Strict,
    LoginOnly,
    Mandatory,
    NoEncryption,
}