Skip to main content

mssql_client/
encryption.rs

1//! Always Encrypted client-side encryption and decryption.
2//!
3//! This module provides the infrastructure for SQL Server's Always Encrypted feature,
4//! which enables client-side encryption of sensitive database columns.
5//!
6//! ## Architecture
7//!
8//! Always Encrypted uses a two-tier key hierarchy:
9//!
10//! ```text
11//! Column Master Key (CMK) - External (KeyVault, CertStore, HSM)
12//!         │
13//!         ▼ RSA-OAEP unwrap
14//! Column Encryption Key (CEK) - Stored encrypted in database
15//!         │
16//!         ▼ AEAD_AES_256_CBC_HMAC_SHA256
17//! Encrypted Column Data
18//! ```
19//!
20//! ## Usage
21//!
22//! ```rust,ignore
23//! use mssql_client::{Config, EncryptionConfig};
24//! use mssql_auth::InMemoryKeyStore;
25//!
26//! // Create encryption configuration
27//! let mut key_store = InMemoryKeyStore::new();
28//! key_store.add_key("MyKey", &pem)?;
29//!
30//! let encryption_config = EncryptionConfig::new()
31//!     .with_provider(key_store)
32//!     .build();
33//!
34//! // Connect with encryption enabled
35//! let config = Config::from_connection_string(conn_str)?
36//!     .with_encryption(encryption_config);
37//!
38//! let client = Client::connect(config).await?;
39//! ```
40//!
41//! ## Security Model
42//!
43//! - **Client-only decryption**: SQL Server never sees plaintext data
44//! - **DBA protection**: Even database administrators cannot read encrypted data
45//! - **Key separation**: CMK stays in secure key store, never transmitted
46
47use std::collections::HashMap;
48
49use mssql_auth::KeyStoreProvider;
50use tds_protocol::crypto::{CekTable, CekTableEntry, CryptoMetadata, EncryptionTypeWire};
51
52#[cfg(feature = "always-encrypted")]
53use mssql_auth::{AeadEncryptor, CekCache, CekCacheKey, EncryptionError};
54#[cfg(feature = "always-encrypted")]
55use std::sync::Arc;
56
57/// Configuration for Always Encrypted feature.
58#[derive(Default)]
59pub struct EncryptionConfig {
60    /// Whether encryption is enabled.
61    pub enabled: bool,
62    /// Registered key store providers.
63    providers: Vec<Box<dyn KeyStoreProvider>>,
64    /// Whether to cache decrypted CEKs for performance.
65    pub cache_ceks: bool,
66}
67
68impl EncryptionConfig {
69    /// Create a new encryption configuration (disabled by default).
70    #[must_use]
71    pub fn new() -> Self {
72        Self {
73            enabled: true,
74            providers: Vec::new(),
75            cache_ceks: true,
76        }
77    }
78
79    /// Register a key store provider.
80    pub fn register_provider(&mut self, provider: impl KeyStoreProvider + 'static) {
81        self.providers.push(Box::new(provider));
82    }
83
84    /// Builder method to add a key store provider.
85    #[must_use]
86    pub fn with_provider(mut self, provider: impl KeyStoreProvider + 'static) -> Self {
87        self.register_provider(provider);
88        self
89    }
90
91    /// Enable or disable CEK caching.
92    #[must_use]
93    pub fn with_cek_caching(mut self, enabled: bool) -> Self {
94        self.cache_ceks = enabled;
95        self
96    }
97
98    /// Get a provider by name.
99    pub fn get_provider(&self, name: &str) -> Option<&dyn KeyStoreProvider> {
100        self.providers
101            .iter()
102            .find(|p| p.provider_name() == name)
103            .map(|p| p.as_ref())
104    }
105
106    /// Check if encryption is ready (enabled and has providers).
107    #[must_use]
108    pub fn is_ready(&self) -> bool {
109        self.enabled && !self.providers.is_empty()
110    }
111}
112
113impl std::fmt::Debug for EncryptionConfig {
114    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115        f.debug_struct("EncryptionConfig")
116            .field("enabled", &self.enabled)
117            .field("provider_count", &self.providers.len())
118            .field("cache_ceks", &self.cache_ceks)
119            .finish()
120    }
121}
122
/// Runtime context for encryption operations.
///
/// This is the active encryption state for a connected client. It holds:
/// - the key store providers, indexed by provider name, that decrypt column
///   encryption keys (CEKs);
/// - an optional cache of the encryptors built from decrypted CEKs.
#[cfg(feature = "always-encrypted")]
pub struct EncryptionContext {
    /// Key store providers by name.
    providers: HashMap<String, Box<dyn KeyStoreProvider>>,
    /// Cache for decrypted CEKs.
    cek_cache: CekCache,
    /// Whether caching is enabled.
    cache_enabled: bool,
}

#[cfg(feature = "always-encrypted")]
impl EncryptionContext {
    /// Create a new encryption context from an `Arc`-wrapped configuration.
    ///
    /// Key store providers are boxed trait objects and cannot be cloned. They
    /// can only be moved into the context when this `Arc` is the sole
    /// reference; in that case this is equivalent to [`EncryptionContext::new`].
    ///
    /// If the configuration is shared:
    /// - a warning is logged;
    /// - the context is created without providers;
    /// - the configuration's `cache_ceks` setting is still honored.
    ///
    /// Providers can be supplied afterwards with
    /// [`EncryptionContext::register_provider`].
    pub fn from_arc(config: Arc<EncryptionConfig>) -> Self {
        match Arc::try_unwrap(config) {
            Ok(owned) => Self::new(owned),
            Err(shared) => {
                tracing::warn!(
                    "EncryptionConfig has multiple references; \
                     creating EncryptionContext without providers"
                );
                Self {
                    providers: HashMap::new(),
                    cek_cache: CekCache::new(),
                    // Providers cannot be moved out of a shared config, but the
                    // caching preference is a plain flag and must still apply.
                    cache_enabled: shared.cache_ceks,
                }
            }
        }
    }

    /// Create a new encryption context from configuration.
    ///
    /// Providers are indexed by [`KeyStoreProvider::provider_name`]. If more
    /// than one provider reports the same name, the last one in registration
    /// order is kept. The configuration's `enabled` flag is not consulted here.
    pub fn new(config: EncryptionConfig) -> Self {
        let cache_enabled = config.cache_ceks;
        let providers = config
            .providers
            .into_iter()
            .map(|p| (p.provider_name().to_string(), p))
            .collect();

        Self {
            providers,
            cek_cache: CekCache::new(),
            cache_enabled,
        }
    }

    /// Register a key store provider on an existing context.
    ///
    /// This is how providers are supplied when the context was built by
    /// [`from_arc`](Self::from_arc) from a shared configuration. A provider
    /// whose name matches an already-registered one replaces it.
    pub fn register_provider(&mut self, provider: impl KeyStoreProvider + 'static) {
        let name = provider.provider_name().to_string();
        self.providers.insert(name, Box::new(provider));
    }

    /// Get or decrypt a CEK for a column.
    ///
    /// This handles the CEK caching and decryption logic:
    /// 1. Check the cache for an existing encryptor (when caching is enabled).
    /// 2. If not cached, decrypt the CEK's primary value using the key store
    ///    provider named in that value.
    /// 3. Create the encryptor, caching it when caching is enabled.
    ///
    /// # Errors
    ///
    /// - [`EncryptionError::CekDecryptionFailed`] if the entry has no CEK value.
    /// - [`EncryptionError::KeyStoreNotFound`] if no provider with the value's
    ///   key store provider name is registered.
    /// - Errors from the provider's `decrypt_cek`, and from building or
    ///   caching the encryptor, are propagated unchanged.
    pub async fn get_encryptor(
        &self,
        cek_entry: &CekTableEntry,
    ) -> Result<Arc<AeadEncryptor>, EncryptionError> {
        let cache_key = CekCacheKey::new(
            cek_entry.database_id,
            cek_entry.cek_id,
            cek_entry.cek_version,
        );

        // Check cache first
        if self.cache_enabled {
            if let Some(encryptor) = self.cek_cache.get(&cache_key) {
                return Ok(encryptor);
            }
        }

        // Get the primary CEK value
        let cek_value = cek_entry
            .primary_value()
            .ok_or_else(|| EncryptionError::CekDecryptionFailed("No CEK value available".into()))?;

        // Find the appropriate key store provider
        let provider = self
            .providers
            .get(&cek_value.key_store_provider_name)
            .ok_or_else(|| {
                EncryptionError::KeyStoreNotFound(cek_value.key_store_provider_name.clone())
            })?;

        // Decrypt the CEK
        let decrypted_cek = provider
            .decrypt_cek(
                &cek_value.cmk_path,
                &cek_value.encryption_algorithm,
                &cek_value.encrypted_value,
            )
            .await?;

        // Create encryptor and cache it
        if self.cache_enabled {
            self.cek_cache.insert(cache_key, decrypted_cek)
        } else {
            // Create encryptor without caching
            Ok(Arc::new(AeadEncryptor::new(&decrypted_cek)?))
        }
    }

    /// Encrypt a value for a column.
    ///
    /// # Arguments
    ///
    /// * `plaintext` - The plaintext value to encrypt
    /// * `cek_entry` - The CEK table entry for this column
    /// * `encryption_type` - Deterministic or randomized encryption
    ///
    /// # Errors
    ///
    /// - Any error from [`get_encryptor`](Self::get_encryptor).
    /// - [`EncryptionError::UnsupportedOperation`] for encryption types other
    ///   than deterministic or randomized.
    /// - Any error from the encryptor itself.
    pub async fn encrypt_value(
        &self,
        plaintext: &[u8],
        cek_entry: &CekTableEntry,
        encryption_type: EncryptionTypeWire,
    ) -> Result<Vec<u8>, EncryptionError> {
        let encryptor = self.get_encryptor(cek_entry).await?;

        let enc_type = match encryption_type {
            EncryptionTypeWire::Deterministic => mssql_auth::EncryptionType::Deterministic,
            EncryptionTypeWire::Randomized => mssql_auth::EncryptionType::Randomized,
            _ => {
                return Err(EncryptionError::UnsupportedOperation(format!(
                    "unsupported encryption type: {encryption_type:?}"
                )));
            }
        };

        encryptor.encrypt(plaintext, enc_type)
    }

    /// Decrypt a value from an encrypted column.
    ///
    /// # Arguments
    ///
    /// * `ciphertext` - The encrypted value
    /// * `cek_entry` - The CEK table entry for this column
    ///
    /// # Errors
    ///
    /// Any error from [`get_encryptor`](Self::get_encryptor) or from the
    /// encryptor's `decrypt`.
    pub async fn decrypt_value(
        &self,
        ciphertext: &[u8],
        cek_entry: &CekTableEntry,
    ) -> Result<Vec<u8>, EncryptionError> {
        let encryptor = self.get_encryptor(cek_entry).await?;
        encryptor.decrypt(ciphertext)
    }

    /// Clear the CEK cache.
    ///
    /// Call this when keys may have been rotated.
    pub fn clear_cache(&self) {
        self.cek_cache.clear();
    }

    /// Check if a provider with the given name is registered.
    pub fn has_provider(&self, name: &str) -> bool {
        self.providers.contains_key(name)
    }
}

#[cfg(feature = "always-encrypted")]
impl std::fmt::Debug for EncryptionContext {
    /// Formats the context, reporting provider names and cache size rather
    /// than the providers or key material themselves.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("EncryptionContext")
            .field("providers", &self.providers.keys().collect::<Vec<_>>())
            .field("cache_entries", &self.cek_cache.len())
            .field("cache_enabled", &self.cache_enabled)
            .finish()
    }
}
299
300/// Column encryption metadata for a result set.
301///
302/// This combines the CEK table with per-column crypto metadata,
303/// providing all information needed to decrypt result columns.
304#[derive(Debug, Clone)]
305pub struct ResultSetEncryptionInfo {
306    /// CEK table for this result set.
307    pub cek_table: CekTable,
308    /// Crypto metadata for each column (index matches column ordinal).
309    pub column_crypto: Vec<Option<CryptoMetadata>>,
310}
311
312impl ResultSetEncryptionInfo {
313    /// Create encryption info for a result set.
314    pub fn new(cek_table: CekTable, column_count: usize) -> Self {
315        Self {
316            cek_table,
317            column_crypto: vec![None; column_count],
318        }
319    }
320
321    /// Set crypto metadata for a column.
322    pub fn set_column_crypto(&mut self, ordinal: usize, metadata: CryptoMetadata) {
323        if ordinal < self.column_crypto.len() {
324            self.column_crypto[ordinal] = Some(metadata);
325        }
326    }
327
328    /// Get the CEK entry for a column.
329    pub fn get_cek_for_column(&self, ordinal: usize) -> Option<&CekTableEntry> {
330        let crypto = self.column_crypto.get(ordinal)?.as_ref()?;
331        self.cek_table.get(crypto.cek_table_ordinal)
332    }
333
334    /// Check if a column is encrypted.
335    pub fn is_column_encrypted(&self, ordinal: usize) -> bool {
336        self.column_crypto
337            .get(ordinal)
338            .map(|c| c.is_some())
339            .unwrap_or(false)
340    }
341
342    /// Get the encryption type for a column.
343    pub fn get_encryption_type(&self, ordinal: usize) -> Option<EncryptionTypeWire> {
344        self.column_crypto
345            .get(ordinal)?
346            .as_ref()
347            .map(|c| c.encryption_type)
348    }
349}
350
351/// Parameter encryption metadata for a query.
352///
353/// This is returned by `sp_describe_parameter_encryption` and describes
354/// how each parameter should be encrypted.
355#[derive(Debug, Clone)]
356pub struct ParameterEncryptionInfo {
357    /// CEK table for parameters.
358    pub cek_table: CekTable,
359    /// Mapping from parameter name to crypto metadata.
360    pub parameters: HashMap<String, ParameterCryptoInfo>,
361}
362
363impl ParameterEncryptionInfo {
364    /// Create empty parameter encryption info.
365    pub fn new() -> Self {
366        Self {
367            cek_table: CekTable::new(),
368            parameters: HashMap::new(),
369        }
370    }
371
372    /// Add encryption info for a parameter.
373    pub fn add_parameter(&mut self, name: String, info: ParameterCryptoInfo) {
374        self.parameters.insert(name, info);
375    }
376
377    /// Get encryption info for a parameter.
378    pub fn get_parameter(&self, name: &str) -> Option<&ParameterCryptoInfo> {
379        self.parameters.get(name)
380    }
381
382    /// Check if a parameter needs encryption.
383    pub fn needs_encryption(&self, name: &str) -> bool {
384        self.parameters.contains_key(name)
385    }
386}
387
388impl Default for ParameterEncryptionInfo {
389    fn default() -> Self {
390        Self::new()
391    }
392}
393
394/// Encryption metadata for a single parameter.
395#[derive(Debug, Clone)]
396pub struct ParameterCryptoInfo {
397    /// Index into the CEK table.
398    pub cek_ordinal: u16,
399    /// Encryption type (deterministic or randomized).
400    pub encryption_type: EncryptionTypeWire,
401    /// Algorithm ID.
402    pub algorithm_id: u8,
403    /// Target column ordinal in the table (for type information).
404    pub column_ordinal: u16,
405    /// Target column database ID.
406    pub database_id: u32,
407}
408
409impl ParameterCryptoInfo {
410    /// Create new parameter crypto info.
411    pub fn new(
412        cek_ordinal: u16,
413        encryption_type: EncryptionTypeWire,
414        algorithm_id: u8,
415        column_ordinal: u16,
416        database_id: u32,
417    ) -> Self {
418        Self {
419            cek_ordinal,
420            encryption_type,
421            algorithm_id,
422            column_ordinal,
423            database_id,
424        }
425    }
426}
427
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Crypto metadata for a deterministically encrypted column that uses
    /// CEK table entry 0 and algorithm id 2.
    fn deterministic_metadata() -> CryptoMetadata {
        CryptoMetadata {
            cek_table_ordinal: 0,
            base_user_type: 0,
            base_col_type: 0x26,
            base_type_info: tds_protocol::token::TypeInfo::default(),
            algorithm_id: 2,
            encryption_type: EncryptionTypeWire::Deterministic,
            normalization_version: 1,
        }
    }

    #[test]
    fn test_encryption_config_defaults() {
        let config = EncryptionConfig::new();
        assert!(config.enabled);
        assert!(config.cache_ceks);
        assert!(!config.is_ready()); // No providers
    }

    #[test]
    fn test_encryption_config_derived_default_is_disabled() {
        // `Default` is derived, so unlike `new()` both flags start out false.
        let config = EncryptionConfig::default();
        assert!(!config.enabled);
        assert!(!config.cache_ceks);
        assert!(!config.is_ready());
    }

    #[test]
    fn test_encryption_config_cek_caching_toggle() {
        let config = EncryptionConfig::new().with_cek_caching(false);
        assert!(config.enabled);
        assert!(!config.cache_ceks);
        assert!(config.get_provider("MSSQL_CERTIFICATE_STORE").is_none());
    }

    #[test]
    fn test_result_set_encryption_info() {
        let cek_table = CekTable::new();
        let mut info = ResultSetEncryptionInfo::new(cek_table, 3);

        assert!(!info.is_column_encrypted(0));
        assert!(!info.is_column_encrypted(1));
        assert!(!info.is_column_encrypted(2));

        info.set_column_crypto(1, deterministic_metadata());
        assert!(!info.is_column_encrypted(0));
        assert!(info.is_column_encrypted(1));
        assert!(!info.is_column_encrypted(2));

        assert_eq!(
            info.get_encryption_type(1),
            Some(EncryptionTypeWire::Deterministic)
        );
        assert_eq!(info.get_encryption_type(0), None);
    }

    #[test]
    fn test_result_set_encryption_info_out_of_range() {
        let mut info = ResultSetEncryptionInfo::new(CekTable::new(), 2);

        // Ordinals beyond the column count are ignored rather than panicking.
        info.set_column_crypto(5, deterministic_metadata());
        assert_eq!(info.column_crypto.len(), 2);
        assert!(info.column_crypto.iter().all(Option::is_none));

        assert!(!info.is_column_encrypted(5));
        assert_eq!(info.get_encryption_type(5), None);
        assert!(info.get_cek_for_column(5).is_none());
        // An unencrypted column has no CEK either.
        assert!(info.get_cek_for_column(0).is_none());
    }

    #[test]
    fn test_parameter_encryption_info() {
        let mut info = ParameterEncryptionInfo::new();

        assert!(!info.needs_encryption("@p1"));

        let crypto = ParameterCryptoInfo::new(0, EncryptionTypeWire::Randomized, 2, 1, 1);
        info.add_parameter("@p1".to_string(), crypto);

        assert!(info.needs_encryption("@p1"));
        assert!(!info.needs_encryption("@p2"));
        assert!(info.get_parameter("@p2").is_none());

        let param = info.get_parameter("@p1").unwrap();
        assert_eq!(param.encryption_type, EncryptionTypeWire::Randomized);

        // Adding a parameter again under the same name replaces the old entry.
        let replacement = ParameterCryptoInfo::new(0, EncryptionTypeWire::Deterministic, 2, 1, 1);
        info.add_parameter("@p1".to_string(), replacement);
        assert_eq!(info.parameters.len(), 1);
        assert_eq!(
            info.get_parameter("@p1").unwrap().encryption_type,
            EncryptionTypeWire::Deterministic
        );
    }

    #[test]
    fn test_parameter_encryption_info_default_is_empty() {
        let info = ParameterEncryptionInfo::default();
        assert!(info.parameters.is_empty());
        assert!(!info.needs_encryption("@p1"));
    }
}