// mssql_tls/config.rs

//! TLS configuration options for SQL Server connections.
2
3use std::sync::Arc;
4
5use rustls::pki_types::{CertificateDer, PrivateKeyDer};
6
7/// Client authentication credentials for mutual TLS.
8///
9/// This is wrapped in an Arc because `PrivateKeyDer` doesn't implement Clone.
10#[derive(Clone)]
11pub struct ClientAuth {
12    /// Client certificate chain.
13    pub certificates: Vec<CertificateDer<'static>>,
14    /// Client private key (wrapped in Arc as it doesn't implement Clone).
15    pub key: Arc<PrivateKeyDer<'static>>,
16}
17
18impl ClientAuth {
19    /// Create new client authentication credentials.
20    pub fn new(certificates: Vec<CertificateDer<'static>>, key: PrivateKeyDer<'static>) -> Self {
21        Self {
22            certificates,
23            key: Arc::new(key),
24        }
25    }
26}
27
28impl std::fmt::Debug for ClientAuth {
29    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30        f.debug_struct("ClientAuth")
31            .field("certificates_count", &self.certificates.len())
32            .field("has_key", &true)
33            .finish()
34    }
35}
36
37/// TLS configuration for SQL Server connections.
38#[derive(Clone, Debug)]
39pub struct TlsConfig {
40    /// Whether to trust the server certificate without validation.
41    ///
42    /// **Warning:** This is insecure and should only be used for testing.
43    pub trust_server_certificate: bool,
44
45    /// Custom root certificates to trust.
46    ///
47    /// If empty, the bundled webpki (Mozilla) root certificates are used. The
48    /// driver does not read the operating system trust store.
49    pub root_certificates: Vec<CertificateDer<'static>>,
50
51    /// Client authentication credentials for mutual TLS (TDS 8.0 client cert auth).
52    pub client_auth: Option<ClientAuth>,
53
54    /// Server hostname for certificate validation.
55    ///
56    /// If not set, the connection hostname is used.
57    pub server_name: Option<String>,
58
59    /// Minimum TLS version to accept.
60    pub min_protocol_version: TlsVersion,
61
62    /// Maximum TLS version to accept.
63    pub max_protocol_version: TlsVersion,
64
65    /// Whether to use TDS 8.0 strict mode (TLS before any TDS traffic).
66    pub strict_mode: bool,
67
68    /// Application-layer protocol negotiation (ALPN) protocols.
69    pub alpn_protocols: Vec<Vec<u8>>,
70}
71
72impl Default for TlsConfig {
73    fn default() -> Self {
74        Self {
75            trust_server_certificate: false,
76            root_certificates: Vec::new(),
77            client_auth: None,
78            server_name: None,
79            min_protocol_version: TlsVersion::Tls12,
80            max_protocol_version: TlsVersion::Tls13,
81            strict_mode: false,
82            alpn_protocols: Vec::new(),
83        }
84    }
85}
86
87impl TlsConfig {
88    /// Create a new TLS configuration with default settings.
89    #[must_use]
90    pub fn new() -> Self {
91        Self::default()
92    }
93
94    /// Trust the server certificate without validation.
95    ///
96    /// **Warning:** This is insecure and should only be used for testing.
97    #[must_use]
98    pub fn trust_server_certificate(mut self, trust: bool) -> Self {
99        self.trust_server_certificate = trust;
100        self
101    }
102
103    /// Add a custom root certificate to trust.
104    #[must_use]
105    pub fn add_root_certificate(mut self, cert: CertificateDer<'static>) -> Self {
106        self.root_certificates.push(cert);
107        self
108    }
109
110    /// Set custom root certificates, replacing any existing ones.
111    #[must_use]
112    pub fn with_root_certificates(mut self, certs: Vec<CertificateDer<'static>>) -> Self {
113        self.root_certificates = certs;
114        self
115    }
116
117    /// Set client certificate and key for mutual TLS.
118    #[must_use]
119    pub fn with_client_auth(
120        mut self,
121        certs: Vec<CertificateDer<'static>>,
122        key: PrivateKeyDer<'static>,
123    ) -> Self {
124        self.client_auth = Some(ClientAuth::new(certs, key));
125        self
126    }
127
128    /// Set the server name for certificate validation.
129    #[must_use]
130    pub fn with_server_name(mut self, name: impl Into<String>) -> Self {
131        self.server_name = Some(name.into());
132        self
133    }
134
135    /// Set the minimum TLS version.
136    #[must_use]
137    pub fn min_protocol_version(mut self, version: TlsVersion) -> Self {
138        self.min_protocol_version = version;
139        self
140    }
141
142    /// Set the maximum TLS version.
143    #[must_use]
144    pub fn max_protocol_version(mut self, version: TlsVersion) -> Self {
145        self.max_protocol_version = version;
146        self
147    }
148
149    /// Enable TDS 8.0 strict mode.
150    #[must_use]
151    pub fn strict_mode(mut self, enabled: bool) -> Self {
152        self.strict_mode = enabled;
153        self
154    }
155
156    /// Set ALPN protocols.
157    #[must_use]
158    pub fn with_alpn_protocols(mut self, protocols: Vec<Vec<u8>>) -> Self {
159        self.alpn_protocols = protocols;
160        self
161    }
162
163    /// Check if client certificate authentication is configured.
164    #[must_use]
165    pub fn has_client_auth(&self) -> bool {
166        self.client_auth.is_some()
167    }
168
169    /// Add a root certificate from DER-encoded bytes.
170    ///
171    /// This is a convenience method that avoids requiring a direct
172    /// dependency on the `rustls` crate. For PEM-encoded certificates,
173    /// parse them first using the `rustls-pemfile` crate.
174    #[must_use]
175    pub fn add_root_certificate_der(self, der_bytes: Vec<u8>) -> Self {
176        self.add_root_certificate(CertificateDer::from(der_bytes))
177    }
178
179    /// Set client certificate and key from DER-encoded bytes.
180    ///
181    /// This is a convenience method that avoids requiring a direct
182    /// dependency on the `rustls` crate.
183    ///
184    /// * `cert_chain_der` - DER-encoded certificate chain
185    /// * `private_key_der` - DER-encoded private key (PKCS#8 format)
186    #[must_use]
187    pub fn with_client_auth_der(
188        self,
189        cert_chain_der: Vec<Vec<u8>>,
190        private_key_der: Vec<u8>,
191    ) -> Self {
192        let certs = cert_chain_der
193            .into_iter()
194            .map(CertificateDer::from)
195            .collect();
196        let key = PrivateKeyDer::Pkcs8(private_key_der.into());
197        self.with_client_auth(certs, key)
198    }
199}
200
/// A TLS protocol version supported by the driver.
///
/// Variants are ordered oldest to newest, so `Tls12 < Tls13`.
/// The default is [`TlsVersion::Tls12`].
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[non_exhaustive]
pub enum TlsVersion {
    /// TLS 1.2
    #[default]
    Tls12,
    /// TLS 1.3
    Tls13,
}
211
212impl TlsVersion {
213    /// Convert to rustls protocol version.
214    #[must_use]
215    pub fn to_rustls(&self) -> &'static rustls::SupportedProtocolVersion {
216        match self {
217            Self::Tls12 => &rustls::version::TLS12,
218            Self::Tls13 => &rustls::version::TLS13,
219        }
220    }
221}