Skip to main content

mssql_tls/
config.rs

1//! TLS configuration options.
2
3use std::sync::Arc;
4
5use rustls::pki_types::{CertificateDer, PrivateKeyDer};
6
7/// Client authentication credentials for mutual TLS.
8///
9/// This is wrapped in an Arc because `PrivateKeyDer` doesn't implement Clone.
10#[derive(Clone)]
11pub struct ClientAuth {
12    /// Client certificate chain.
13    pub certificates: Vec<CertificateDer<'static>>,
14    /// Client private key (wrapped in Arc as it doesn't implement Clone).
15    pub key: Arc<PrivateKeyDer<'static>>,
16}
17
18impl ClientAuth {
19    /// Create new client authentication credentials.
20    pub fn new(certificates: Vec<CertificateDer<'static>>, key: PrivateKeyDer<'static>) -> Self {
21        Self {
22            certificates,
23            key: Arc::new(key),
24        }
25    }
26}
27
28impl std::fmt::Debug for ClientAuth {
29    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30        f.debug_struct("ClientAuth")
31            .field("certificates_count", &self.certificates.len())
32            .field("has_key", &true)
33            .finish()
34    }
35}
36
37/// TLS configuration for SQL Server connections.
38#[derive(Clone, Debug)]
39pub struct TlsConfig {
40    /// Whether to trust the server certificate without validation.
41    ///
42    /// **Warning:** This is insecure and should only be used for testing.
43    pub trust_server_certificate: bool,
44
45    /// Custom root certificates to trust.
46    ///
47    /// If empty, the system root certificates are used.
48    pub root_certificates: Vec<CertificateDer<'static>>,
49
50    /// Client authentication credentials for mutual TLS (TDS 8.0 client cert auth).
51    pub client_auth: Option<ClientAuth>,
52
53    /// Server hostname for certificate validation.
54    ///
55    /// If not set, the connection hostname is used.
56    pub server_name: Option<String>,
57
58    /// Minimum TLS version to accept.
59    pub min_protocol_version: TlsVersion,
60
61    /// Maximum TLS version to accept.
62    pub max_protocol_version: TlsVersion,
63
64    /// Whether to use TDS 8.0 strict mode (TLS before any TDS traffic).
65    pub strict_mode: bool,
66
67    /// Application-layer protocol negotiation (ALPN) protocols.
68    pub alpn_protocols: Vec<Vec<u8>>,
69}
70
71impl Default for TlsConfig {
72    fn default() -> Self {
73        Self {
74            trust_server_certificate: false,
75            root_certificates: Vec::new(),
76            client_auth: None,
77            server_name: None,
78            min_protocol_version: TlsVersion::Tls12,
79            max_protocol_version: TlsVersion::Tls13,
80            strict_mode: false,
81            alpn_protocols: Vec::new(),
82        }
83    }
84}
85
86impl TlsConfig {
87    /// Create a new TLS configuration with default settings.
88    #[must_use]
89    pub fn new() -> Self {
90        Self::default()
91    }
92
93    /// Trust the server certificate without validation.
94    ///
95    /// **Warning:** This is insecure and should only be used for testing.
96    #[must_use]
97    pub fn trust_server_certificate(mut self, trust: bool) -> Self {
98        self.trust_server_certificate = trust;
99        self
100    }
101
102    /// Add a custom root certificate to trust.
103    #[must_use]
104    pub fn add_root_certificate(mut self, cert: CertificateDer<'static>) -> Self {
105        self.root_certificates.push(cert);
106        self
107    }
108
109    /// Set custom root certificates, replacing any existing ones.
110    #[must_use]
111    pub fn with_root_certificates(mut self, certs: Vec<CertificateDer<'static>>) -> Self {
112        self.root_certificates = certs;
113        self
114    }
115
116    /// Set client certificate and key for mutual TLS.
117    #[must_use]
118    pub fn with_client_auth(
119        mut self,
120        certs: Vec<CertificateDer<'static>>,
121        key: PrivateKeyDer<'static>,
122    ) -> Self {
123        self.client_auth = Some(ClientAuth::new(certs, key));
124        self
125    }
126
127    /// Set the server name for certificate validation.
128    #[must_use]
129    pub fn with_server_name(mut self, name: impl Into<String>) -> Self {
130        self.server_name = Some(name.into());
131        self
132    }
133
134    /// Set the minimum TLS version.
135    #[must_use]
136    pub fn min_protocol_version(mut self, version: TlsVersion) -> Self {
137        self.min_protocol_version = version;
138        self
139    }
140
141    /// Set the maximum TLS version.
142    #[must_use]
143    pub fn max_protocol_version(mut self, version: TlsVersion) -> Self {
144        self.max_protocol_version = version;
145        self
146    }
147
148    /// Enable TDS 8.0 strict mode.
149    #[must_use]
150    pub fn strict_mode(mut self, enabled: bool) -> Self {
151        self.strict_mode = enabled;
152        self
153    }
154
155    /// Set ALPN protocols.
156    #[must_use]
157    pub fn with_alpn_protocols(mut self, protocols: Vec<Vec<u8>>) -> Self {
158        self.alpn_protocols = protocols;
159        self
160    }
161
162    /// Check if client certificate authentication is configured.
163    #[must_use]
164    pub fn has_client_auth(&self) -> bool {
165        self.client_auth.is_some()
166    }
167
168    /// Add a root certificate from DER-encoded bytes.
169    ///
170    /// This is a convenience method that avoids requiring a direct
171    /// dependency on the `rustls` crate. For PEM-encoded certificates,
172    /// parse them first using the `rustls-pemfile` crate.
173    #[must_use]
174    pub fn add_root_certificate_der(self, der_bytes: Vec<u8>) -> Self {
175        self.add_root_certificate(CertificateDer::from(der_bytes))
176    }
177
178    /// Set client certificate and key from DER-encoded bytes.
179    ///
180    /// This is a convenience method that avoids requiring a direct
181    /// dependency on the `rustls` crate.
182    ///
183    /// * `cert_chain_der` - DER-encoded certificate chain
184    /// * `private_key_der` - DER-encoded private key (PKCS#8 format)
185    #[must_use]
186    pub fn with_client_auth_der(
187        self,
188        cert_chain_der: Vec<Vec<u8>>,
189        private_key_der: Vec<u8>,
190    ) -> Self {
191        let certs = cert_chain_der
192            .into_iter()
193            .map(CertificateDer::from)
194            .collect();
195        let key = PrivateKeyDer::Pkcs8(private_key_der.into());
196        self.with_client_auth(certs, key)
197    }
198}
199
/// A TLS protocol version usable as a bound in [`TlsConfig`].
///
/// Variants are declared oldest-first, so the derived ordering is
/// chronological (`Tls12 < Tls13`). The `Default` value is [`TlsVersion::Tls12`].
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum TlsVersion {
    /// TLS 1.2
    #[default]
    Tls12,
    /// TLS 1.3
    Tls13,
}
210
211impl TlsVersion {
212    /// Convert to rustls protocol version.
213    #[must_use]
214    pub fn to_rustls(&self) -> &'static rustls::SupportedProtocolVersion {
215        match self {
216            Self::Tls12 => &rustls::version::TLS12,
217            Self::Tls13 => &rustls::version::TLS13,
218        }
219    }
220}