Skip to main content

mssql_client/
encryption.rs

1//! Always Encrypted client-side encryption and decryption.
2//!
3//! This module provides the infrastructure for SQL Server's Always Encrypted feature,
4//! which enables client-side encryption of sensitive database columns.
5//!
6//! ## Architecture
7//!
8//! Always Encrypted uses a two-tier key hierarchy:
9//!
10//! ```text
11//! Column Master Key (CMK) - External (KeyVault, CertStore, HSM)
12//!         │
13//!         ▼ RSA-OAEP unwrap
14//! Column Encryption Key (CEK) - Stored encrypted in database
15//!         │
16//!         ▼ AEAD_AES_256_CBC_HMAC_SHA256
17//! Encrypted Column Data
18//! ```
19//!
20//! ## Usage
21//!
22//! ```rust,no_run
23//! # async fn with_always_encrypted() -> Result<(), Box<dyn std::error::Error>> {
24//! # #[cfg(feature = "always-encrypted")]
25//! # {
26//! # let conn_str = "Server=localhost;Database=db;Encrypt=strict;Column Encryption Setting=Enabled";
27//! use mssql_client::{Client, Config, EncryptionConfig};
28//! use mssql_auth::InMemoryKeyStore;
29//!
30//! // Register a key-store provider, then attach it to the connection config.
31//! let key_store = InMemoryKeyStore::new();
32//! let encryption_config = EncryptionConfig::new().with_provider(key_store);
33//!
34//! let config = Config::from_connection_string(conn_str)?
35//!     .with_column_encryption(encryption_config);
36//!
37//! let _client = Client::connect(config).await?;
38//! # }
39//! # Ok(())
40//! # }
41//! ```
42//!
43//! Equivalently, set `Column Encryption Setting=Enabled` in the connection
44//! string. Production-ready providers ship in `mssql-auth`: `InMemoryKeyStore`
45//! (dev/test), `AzureKeyVaultProvider` (`azure-identity` feature), and
46//! `WindowsCertStoreProvider` (`sspi-auth`, Windows). Implement
47//! [`mssql_auth::KeyStoreProvider`] for custom key storage. Do **not** substitute
48//! T-SQL `ENCRYPTBYKEY` — the server can see that plaintext, defeating the point.
49//!
50//! ## How decryption works
51//!
52//! 1. Always Encrypted support is negotiated in LOGIN7 (`FEATURE_EXT`).
53//! 2. `ColMetaData` carries [`CryptoMetadata`] and the [`CekTable`]; column
54//!    encryption keys are resolved asynchronously up front (calling the key-store
55//!    providers).
56//! 3. Each encrypted cell is decrypted during row parsing via
57//!    AEAD_AES_256_CBC_HMAC_SHA256, with the HMAC verified before decryption.
58//!
59//! Reads are transparent across `query`, `call_procedure`, the procedure
60//! builder, and multi-result queries. **Limitation:** the parameter (write) path
61//! only supports `NULL`; encrypting outbound parameter values is not yet
62//! implemented.
63//!
64//! ## Security Model
65//!
66//! - **Client-only decryption**: SQL Server never sees plaintext data
67//! - **DBA protection**: Even database administrators cannot read encrypted data
68//! - **Key separation**: CMK stays in secure key store, never transmitted
69
70use std::collections::HashMap;
71
72use mssql_auth::KeyStoreProvider;
73use tds_protocol::crypto::{CekTable, CekTableEntry, CryptoMetadata, EncryptionTypeWire};
74
75#[cfg(feature = "always-encrypted")]
76use mssql_auth::{AeadEncryptor, CekCache, CekCacheKey, EncryptionError};
77#[cfg(feature = "always-encrypted")]
78use std::sync::Arc;
79
80/// Configuration for Always Encrypted feature.
81#[derive(Default)]
82pub struct EncryptionConfig {
83    /// Whether encryption is enabled.
84    pub enabled: bool,
85    /// Registered key store providers.
86    providers: Vec<Box<dyn KeyStoreProvider>>,
87    /// Whether to cache decrypted CEKs for performance.
88    pub cache_ceks: bool,
89}
90
91impl EncryptionConfig {
92    /// Create a new encryption configuration (disabled by default).
93    #[must_use]
94    pub fn new() -> Self {
95        Self {
96            enabled: true,
97            providers: Vec::new(),
98            cache_ceks: true,
99        }
100    }
101
102    /// Register a key store provider.
103    pub fn register_provider(&mut self, provider: impl KeyStoreProvider + 'static) {
104        self.providers.push(Box::new(provider));
105    }
106
107    /// Builder method to add a key store provider.
108    #[must_use]
109    pub fn with_provider(mut self, provider: impl KeyStoreProvider + 'static) -> Self {
110        self.register_provider(provider);
111        self
112    }
113
114    /// Enable or disable CEK caching.
115    #[must_use]
116    pub fn with_cek_caching(mut self, enabled: bool) -> Self {
117        self.cache_ceks = enabled;
118        self
119    }
120
121    /// Get a provider by name.
122    pub fn get_provider(&self, name: &str) -> Option<&dyn KeyStoreProvider> {
123        self.providers
124            .iter()
125            .find(|p| p.provider_name() == name)
126            .map(|p| p.as_ref())
127    }
128
129    /// Check if encryption is ready (enabled and has providers).
130    #[must_use]
131    pub fn is_ready(&self) -> bool {
132        self.enabled && !self.providers.is_empty()
133    }
134}
135
136impl std::fmt::Debug for EncryptionConfig {
137    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
138        f.debug_struct("EncryptionConfig")
139            .field("enabled", &self.enabled)
140            .field("provider_count", &self.providers.len())
141            .field("cache_ceks", &self.cache_ceks)
142            .finish()
143    }
144}
145
/// Live encryption state belonging to a connected client.
///
/// Bundles the shared [`EncryptionConfig`] (the source of key store
/// providers) with a cache of column encryption keys and the caching switch.
///
/// The configuration is held behind an `Arc` so that the providers stay
/// reachable across connection retries/redirects, where the `Config` (and the
/// encryption config `Arc` inside it) is cloned several times.
#[cfg(feature = "always-encrypted")]
pub struct EncryptionContext {
    /// Shared reference to the user-supplied configuration. Provider lookups
    /// by name go through it, so any number of `Arc` clones keep access to
    /// the same providers.
    config: Arc<EncryptionConfig>,
    /// Cache of CEKs that have already been decrypted.
    cek_cache: CekCache,
    /// `true` when `cek_cache` is consulted and populated.
    cache_enabled: bool,
}
165
#[cfg(feature = "always-encrypted")]
impl EncryptionContext {
    /// Create a new encryption context from an Arc-wrapped configuration.
    ///
    /// The `Arc` is stored in the context, so provider lookups keep working
    /// for as long as the context lives. The caching switch is copied from
    /// `config.cache_ceks` at construction time and the CEK cache starts out
    /// empty.
    pub fn from_arc(config: Arc<EncryptionConfig>) -> Self {
        Self {
            cache_enabled: config.cache_ceks,
            cek_cache: CekCache::new(),
            config,
        }
    }

    /// Create a new encryption context from configuration.
    ///
    /// Wraps `config` in a fresh `Arc` and delegates to [`Self::from_arc`].
    pub fn new(config: EncryptionConfig) -> Self {
        Self::from_arc(Arc::new(config))
    }

    /// Get or decrypt a CEK for a column.
    ///
    /// Steps:
    /// 1. When caching is enabled, return the cached encryptor for
    ///    (`database_id`, `cek_id`, `cek_version`) if there is one.
    /// 2. Otherwise take the entry's primary CEK value, look up the key store
    ///    provider named by it, and ask that provider to decrypt the CEK.
    /// 3. When caching is enabled, hand the decrypted key to the cache and
    ///    return whatever `insert` yields; otherwise build a one-off
    ///    encryptor.
    ///
    /// # Errors
    ///
    /// * [`EncryptionError::CekDecryptionFailed`] when the entry has no
    ///   primary CEK value.
    /// * [`EncryptionError::KeyStoreNotFound`] when no registered provider
    ///   matches the CEK value's provider name.
    /// * Any error reported by the provider's `decrypt_cek`, by the cache
    ///   insert, or by encryptor construction is propagated unchanged.
    pub async fn get_encryptor(
        &self,
        cek_entry: &CekTableEntry,
    ) -> Result<Arc<AeadEncryptor>, EncryptionError> {
        let cache_key = CekCacheKey::new(
            cek_entry.database_id,
            cek_entry.cek_id,
            cek_entry.cek_version,
        );

        // Fast path: reuse an encryptor that was built earlier.
        if self.cache_enabled {
            if let Some(encryptor) = self.cek_cache.get(&cache_key) {
                return Ok(encryptor);
            }
        }

        // Get the primary CEK value
        let cek_value = cek_entry
            .primary_value()
            .ok_or_else(|| EncryptionError::CekDecryptionFailed("No CEK value available".into()))?;

        // Find the appropriate key store provider via the shared config
        let provider = self
            .config
            .get_provider(&cek_value.key_store_provider_name)
            .ok_or_else(|| {
                EncryptionError::KeyStoreNotFound(cek_value.key_store_provider_name.clone())
            })?;

        // Decrypt the CEK
        let decrypted_cek = provider
            .decrypt_cek(
                &cek_value.cmk_path,
                &cek_value.encryption_algorithm,
                &cek_value.encrypted_value,
            )
            .await?;

        if self.cache_enabled {
            // The cache's insert result is returned as-is.
            self.cek_cache.insert(cache_key, decrypted_cek)
        } else {
            // Create encryptor without caching
            Ok(Arc::new(AeadEncryptor::new(&decrypted_cek)?))
        }
    }

    /// Encrypt a value for a column.
    ///
    /// The requested encryption type is validated *before* the CEK is
    /// resolved, so a request that can never succeed does not trigger a
    /// key store call or populate the CEK cache.
    ///
    /// # Arguments
    ///
    /// * `plaintext` - The plaintext value to encrypt
    /// * `cek_entry` - The CEK table entry for this column
    /// * `encryption_type` - Deterministic or randomized encryption
    ///
    /// # Errors
    ///
    /// * [`EncryptionError::UnsupportedOperation`] when `encryption_type` is
    ///   neither deterministic nor randomized.
    /// * Any error from [`Self::get_encryptor`] or from the encryptor itself.
    pub async fn encrypt_value(
        &self,
        plaintext: &[u8],
        cek_entry: &CekTableEntry,
        encryption_type: EncryptionTypeWire,
    ) -> Result<Vec<u8>, EncryptionError> {
        // Map the wire-level type first: this is cheap and purely local,
        // whereas resolving the encryptor may involve the key store.
        let enc_type = match encryption_type {
            EncryptionTypeWire::Deterministic => mssql_auth::EncryptionType::Deterministic,
            EncryptionTypeWire::Randomized => mssql_auth::EncryptionType::Randomized,
            _ => {
                return Err(EncryptionError::UnsupportedOperation(format!(
                    "unsupported encryption type: {encryption_type:?}"
                )));
            }
        };

        let encryptor = self.get_encryptor(cek_entry).await?;
        encryptor.encrypt(plaintext, enc_type)
    }

    /// Decrypt a value from an encrypted column.
    ///
    /// # Arguments
    ///
    /// * `ciphertext` - The encrypted value
    /// * `cek_entry` - The CEK table entry for this column
    ///
    /// # Errors
    ///
    /// Any error from [`Self::get_encryptor`] or from the encryptor's
    /// `decrypt` is propagated unchanged.
    pub async fn decrypt_value(
        &self,
        ciphertext: &[u8],
        cek_entry: &CekTableEntry,
    ) -> Result<Vec<u8>, EncryptionError> {
        let encryptor = self.get_encryptor(cek_entry).await?;
        encryptor.decrypt(ciphertext)
    }

    /// Clear the CEK cache.
    ///
    /// Call this when keys may have been rotated.
    pub fn clear_cache(&self) {
        self.cek_cache.clear();
    }

    /// Check if a provider is registered.
    ///
    /// Delegates to [`EncryptionConfig::get_provider`] on the shared
    /// configuration.
    pub fn has_provider(&self, name: &str) -> bool {
        self.config.get_provider(name).is_some()
    }
}
296
#[cfg(feature = "always-encrypted")]
impl std::fmt::Debug for EncryptionContext {
    /// Prints summary counters only: the number of registered providers, the
    /// number of cache entries, and the caching switch.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let provider_count = self.config.providers.len();
        let cache_entries = self.cek_cache.len();

        let mut out = f.debug_struct("EncryptionContext");
        out.field("provider_count", &provider_count);
        out.field("cache_entries", &cache_entries);
        out.field("cache_enabled", &self.cache_enabled);
        out.finish()
    }
}
307
308/// Column encryption metadata for a result set.
309///
310/// This combines the CEK table with per-column crypto metadata,
311/// providing all information needed to decrypt result columns.
312#[derive(Debug, Clone)]
313pub struct ResultSetEncryptionInfo {
314    /// CEK table for this result set.
315    pub cek_table: CekTable,
316    /// Crypto metadata for each column (index matches column ordinal).
317    pub column_crypto: Vec<Option<CryptoMetadata>>,
318}
319
320impl ResultSetEncryptionInfo {
321    /// Create encryption info for a result set.
322    pub fn new(cek_table: CekTable, column_count: usize) -> Self {
323        Self {
324            cek_table,
325            column_crypto: vec![None; column_count],
326        }
327    }
328
329    /// Set crypto metadata for a column.
330    pub fn set_column_crypto(&mut self, ordinal: usize, metadata: CryptoMetadata) {
331        if ordinal < self.column_crypto.len() {
332            self.column_crypto[ordinal] = Some(metadata);
333        }
334    }
335
336    /// Get the CEK entry for a column.
337    pub fn get_cek_for_column(&self, ordinal: usize) -> Option<&CekTableEntry> {
338        let crypto = self.column_crypto.get(ordinal)?.as_ref()?;
339        self.cek_table.get(crypto.cek_table_ordinal)
340    }
341
342    /// Check if a column is encrypted.
343    pub fn is_column_encrypted(&self, ordinal: usize) -> bool {
344        self.column_crypto
345            .get(ordinal)
346            .map(|c| c.is_some())
347            .unwrap_or(false)
348    }
349
350    /// Get the encryption type for a column.
351    pub fn get_encryption_type(&self, ordinal: usize) -> Option<EncryptionTypeWire> {
352        self.column_crypto
353            .get(ordinal)?
354            .as_ref()
355            .map(|c| c.encryption_type)
356    }
357}
358
359/// Parameter encryption metadata for a query.
360///
361/// This is returned by `sp_describe_parameter_encryption` and describes
362/// how each parameter should be encrypted.
363#[derive(Debug, Clone)]
364pub struct ParameterEncryptionInfo {
365    /// CEK table for parameters.
366    pub cek_table: CekTable,
367    /// Mapping from parameter name to crypto metadata.
368    pub parameters: HashMap<String, ParameterCryptoInfo>,
369}
370
371impl ParameterEncryptionInfo {
372    /// Create empty parameter encryption info.
373    pub fn new() -> Self {
374        Self {
375            cek_table: CekTable::new(),
376            parameters: HashMap::new(),
377        }
378    }
379
380    /// Add encryption info for a parameter.
381    pub fn add_parameter(&mut self, name: String, info: ParameterCryptoInfo) {
382        self.parameters.insert(name, info);
383    }
384
385    /// Get encryption info for a parameter.
386    pub fn get_parameter(&self, name: &str) -> Option<&ParameterCryptoInfo> {
387        self.parameters.get(name)
388    }
389
390    /// Check if a parameter needs encryption.
391    pub fn needs_encryption(&self, name: &str) -> bool {
392        self.parameters.contains_key(name)
393    }
394}
395
396impl Default for ParameterEncryptionInfo {
397    fn default() -> Self {
398        Self::new()
399    }
400}
401
402/// Encryption metadata for a single parameter.
403#[derive(Debug, Clone)]
404pub struct ParameterCryptoInfo {
405    /// Index into the CEK table.
406    pub cek_ordinal: u16,
407    /// Encryption type (deterministic or randomized).
408    pub encryption_type: EncryptionTypeWire,
409    /// Algorithm ID.
410    pub algorithm_id: u8,
411    /// Target column ordinal in the table (for type information).
412    pub column_ordinal: u16,
413    /// Target column database ID.
414    pub database_id: u32,
415}
416
417impl ParameterCryptoInfo {
418    /// Create new parameter crypto info.
419    pub fn new(
420        cek_ordinal: u16,
421        encryption_type: EncryptionTypeWire,
422        algorithm_id: u8,
423        column_ordinal: u16,
424        database_id: u32,
425    ) -> Self {
426        Self {
427            cek_ordinal,
428            encryption_type,
429            algorithm_id,
430            column_ordinal,
431            database_id,
432        }
433    }
434}
435
#[cfg(test)]
// Test code may unwrap/expect freely: a panic is simply a failed test.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// `new()` enables encryption and caching but registers no provider, so
    /// the configuration is not ready yet.
    #[test]
    fn test_encryption_config_defaults() {
        let config = EncryptionConfig::new();
        assert!(config.enabled);
        assert!(config.cache_ceks);
        assert!(!config.is_ready()); // No providers

        // Nothing is registered, so no name can resolve to a provider.
        assert!(config.get_provider("ANY_PROVIDER").is_none());

        // The builder toggles only the caching flag.
        let config = config.with_cek_caching(false);
        assert!(config.enabled);
        assert!(!config.cache_ceks);
    }

    /// Column crypto metadata is tracked per ordinal; out-of-range ordinals
    /// are ignored on write and report "not encrypted" on read.
    #[test]
    fn test_result_set_encryption_info() {
        let cek_table = CekTable::new();
        let mut info = ResultSetEncryptionInfo::new(cek_table, 3);

        assert!(!info.is_column_encrypted(0));
        assert!(!info.is_column_encrypted(1));
        assert!(!info.is_column_encrypted(2));

        let metadata = CryptoMetadata {
            cek_table_ordinal: 0,
            base_user_type: 0,
            base_col_type: 0x26,
            base_type_info: tds_protocol::token::TypeInfo::default(),
            algorithm_id: 2,
            encryption_type: EncryptionTypeWire::Deterministic,
            normalization_version: 1,
        };

        info.set_column_crypto(1, metadata.clone());
        assert!(!info.is_column_encrypted(0));
        assert!(info.is_column_encrypted(1));
        assert!(!info.is_column_encrypted(2));

        assert_eq!(
            info.get_encryption_type(1),
            Some(EncryptionTypeWire::Deterministic)
        );

        // Columns without metadata have neither an encryption type nor a CEK.
        assert_eq!(info.get_encryption_type(0), None);
        assert!(info.get_cek_for_column(0).is_none());

        // Ordinals past the last column are ignored on write ...
        info.set_column_crypto(3, metadata);
        assert_eq!(info.column_crypto.len(), 3);

        // ... and read back as "not encrypted".
        assert!(!info.is_column_encrypted(3));
        assert_eq!(info.get_encryption_type(3), None);
        assert!(info.get_cek_for_column(3).is_none());
    }

    /// Parameters are looked up by exact name; adding the same name again
    /// replaces the earlier metadata.
    #[test]
    fn test_parameter_encryption_info() {
        let mut info = ParameterEncryptionInfo::new();

        assert!(!info.needs_encryption("@p1"));
        assert!(info.get_parameter("@p1").is_none());

        let crypto = ParameterCryptoInfo::new(0, EncryptionTypeWire::Randomized, 2, 1, 1);
        info.add_parameter("@p1".to_string(), crypto);

        assert!(info.needs_encryption("@p1"));
        assert!(!info.needs_encryption("@p2"));
        assert!(info.get_parameter("@p2").is_none());

        let param = info.get_parameter("@p1").unwrap();
        assert_eq!(param.encryption_type, EncryptionTypeWire::Randomized);

        // Re-adding the same parameter name overwrites the previous entry.
        let replacement = ParameterCryptoInfo::new(0, EncryptionTypeWire::Deterministic, 2, 1, 1);
        info.add_parameter("@p1".to_string(), replacement);
        assert_eq!(info.parameters.len(), 1);

        let param = info.get_parameter("@p1").unwrap();
        assert_eq!(param.encryption_type, EncryptionTypeWire::Deterministic);
    }
}
494}