mssql_client/
encryption.rs

1//! Always Encrypted client-side encryption and decryption.
2//!
3//! This module provides the infrastructure for SQL Server's Always Encrypted feature,
4//! which enables client-side encryption of sensitive database columns.
5//!
6//! ## Architecture
7//!
8//! Always Encrypted uses a two-tier key hierarchy:
9//!
10//! ```text
11//! Column Master Key (CMK) - External (KeyVault, CertStore, HSM)
12//!         │
13//!         ▼ RSA-OAEP unwrap
14//! Column Encryption Key (CEK) - Stored encrypted in database
15//!         │
16//!         ▼ AEAD_AES_256_CBC_HMAC_SHA256
17//! Encrypted Column Data
18//! ```
19//!
20//! ## Usage
21//!
22//! ```rust,ignore
23//! use mssql_client::{Config, EncryptionConfig};
24//! use mssql_auth::InMemoryKeyStore;
25//!
26//! // Create encryption configuration
27//! let mut key_store = InMemoryKeyStore::new();
28//! key_store.add_key("MyKey", &pem)?;
29//!
//! let encryption_config = EncryptionConfig::new()
//!     .with_provider(key_store);
33//!
34//! // Connect with encryption enabled
35//! let config = Config::from_connection_string(conn_str)?
36//!     .with_encryption(encryption_config);
37//!
38//! let client = Client::connect(config).await?;
39//! ```
40//!
41//! ## Security Model
42//!
43//! - **Client-only decryption**: SQL Server never sees plaintext data
44//! - **DBA protection**: Even database administrators cannot read encrypted data
45//! - **Key separation**: CMK stays in secure key store, never transmitted
46
47use std::collections::HashMap;
48
49use mssql_auth::KeyStoreProvider;
50use tds_protocol::crypto::{CekTable, CekTableEntry, CryptoMetadata, EncryptionTypeWire};
51
52#[cfg(feature = "always-encrypted")]
53use mssql_auth::{AeadEncryptor, CekCache, CekCacheKey, EncryptionError};
54#[cfg(feature = "always-encrypted")]
55use std::sync::Arc;
56
57/// Configuration for Always Encrypted feature.
58#[derive(Default)]
59pub struct EncryptionConfig {
60    /// Whether encryption is enabled.
61    pub enabled: bool,
62    /// Registered key store providers.
63    providers: Vec<Box<dyn KeyStoreProvider>>,
64    /// Whether to cache decrypted CEKs for performance.
65    pub cache_ceks: bool,
66}
67
68impl EncryptionConfig {
69    /// Create a new encryption configuration (disabled by default).
70    #[must_use]
71    pub fn new() -> Self {
72        Self {
73            enabled: true,
74            providers: Vec::new(),
75            cache_ceks: true,
76        }
77    }
78
79    /// Register a key store provider.
80    pub fn register_provider(&mut self, provider: impl KeyStoreProvider + 'static) {
81        self.providers.push(Box::new(provider));
82    }
83
84    /// Builder method to add a key store provider.
85    #[must_use]
86    pub fn with_provider(mut self, provider: impl KeyStoreProvider + 'static) -> Self {
87        self.register_provider(provider);
88        self
89    }
90
91    /// Enable or disable CEK caching.
92    #[must_use]
93    pub fn with_cek_caching(mut self, enabled: bool) -> Self {
94        self.cache_ceks = enabled;
95        self
96    }
97
98    /// Get a provider by name.
99    pub fn get_provider(&self, name: &str) -> Option<&dyn KeyStoreProvider> {
100        self.providers
101            .iter()
102            .find(|p| p.provider_name() == name)
103            .map(|p| p.as_ref())
104    }
105
106    /// Check if encryption is ready (enabled and has providers).
107    #[must_use]
108    pub fn is_ready(&self) -> bool {
109        self.enabled && !self.providers.is_empty()
110    }
111}
112
113impl std::fmt::Debug for EncryptionConfig {
114    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115        f.debug_struct("EncryptionConfig")
116            .field("enabled", &self.enabled)
117            .field("provider_count", &self.providers.len())
118            .field("cache_ceks", &self.cache_ceks)
119            .finish()
120    }
121}
122
/// Live encryption state belonging to a connected client.
///
/// Built from an [`EncryptionConfig`]; owns the registered key store
/// providers together with the cache used for decrypted CEKs.
#[cfg(feature = "always-encrypted")]
pub struct EncryptionContext {
    /// Registered key store providers, keyed by their provider name.
    providers: HashMap<String, Box<dyn KeyStoreProvider>>,
    /// Cache consulted before a CEK is decrypted again.
    cek_cache: CekCache,
    /// When `false`, `cek_cache` is neither read nor written by `get_encryptor`.
    cache_enabled: bool,
}
136
#[cfg(feature = "always-encrypted")]
impl EncryptionContext {
    /// Build a context from `config`, taking ownership of its providers.
    ///
    /// Providers are indexed by their `provider_name()`. If two providers
    /// report the same name, the one registered later replaces the earlier.
    pub fn new(config: EncryptionConfig) -> Self {
        let cache_enabled = config.cache_ceks;

        let mut providers = HashMap::with_capacity(config.providers.len());
        for provider in config.providers {
            providers.insert(provider.provider_name().to_string(), provider);
        }

        Self {
            providers,
            cek_cache: CekCache::new(),
            cache_enabled,
        }
    }

    /// Return the encryptor for the CEK described by `cek_entry`.
    ///
    /// Steps:
    /// 1. With caching enabled, return the cached encryptor if there is one.
    /// 2. Otherwise take the entry's primary CEK value and ask the key store
    ///    provider it names to decrypt it.
    /// 3. With caching enabled, hand the decrypted key to the cache and return
    ///    what the cache yields; without caching, build a fresh encryptor.
    ///
    /// # Errors
    ///
    /// * `EncryptionError::CekDecryptionFailed` when the entry has no primary value.
    /// * `EncryptionError::KeyStoreNotFound` when the named provider is not registered.
    /// * Any error from the provider's `decrypt_cek`, from the cache insert, or
    ///   from `AeadEncryptor::new` is propagated unchanged.
    pub async fn get_encryptor(
        &self,
        cek_entry: &CekTableEntry,
    ) -> Result<Arc<AeadEncryptor>, EncryptionError> {
        let key = CekCacheKey::new(
            cek_entry.database_id,
            cek_entry.cek_id,
            cek_entry.cek_version,
        );

        // Fast path: reuse a previously built encryptor.
        if self.cache_enabled {
            if let Some(cached) = self.cek_cache.get(&key) {
                return Ok(cached);
            }
        }

        // The entry must carry a primary (encrypted) CEK value.
        let cek_value = match cek_entry.primary_value() {
            Some(value) => value,
            None => {
                return Err(EncryptionError::CekDecryptionFailed(
                    "No CEK value available".into(),
                ));
            }
        };

        // Look up the key store that the CEK value names.
        let provider = match self.providers.get(&cek_value.key_store_provider_name) {
            Some(found) => found,
            None => {
                return Err(EncryptionError::KeyStoreNotFound(
                    cek_value.key_store_provider_name.clone(),
                ));
            }
        };

        // Ask the key store to decrypt the CEK.
        let plaintext_key = provider
            .decrypt_cek(
                &cek_value.cmk_path,
                &cek_value.encryption_algorithm,
                &cek_value.encrypted_value,
            )
            .await?;

        if !self.cache_enabled {
            // Caching is off: build a one-off encryptor.
            return Ok(Arc::new(AeadEncryptor::new(&plaintext_key)?));
        }

        // Caching is on: the cache insert yields the encryptor.
        self.cek_cache.insert(key, plaintext_key)
    }

    /// Encrypt `plaintext` for a column.
    ///
    /// # Arguments
    ///
    /// * `plaintext` - Bytes to encrypt
    /// * `cek_entry` - CEK table entry that applies to the column
    /// * `encryption_type` - Wire-level encryption type (deterministic or randomized)
    ///
    /// # Errors
    ///
    /// Propagates failures from [`get_encryptor`](Self::get_encryptor) and
    /// from the encryptor itself.
    pub async fn encrypt_value(
        &self,
        plaintext: &[u8],
        cek_entry: &CekTableEntry,
        encryption_type: EncryptionTypeWire,
    ) -> Result<Vec<u8>, EncryptionError> {
        let encryptor = self.get_encryptor(cek_entry).await?;

        // Translate the wire-level enum into the one the encryptor expects.
        encryptor.encrypt(
            plaintext,
            match encryption_type {
                EncryptionTypeWire::Deterministic => mssql_auth::EncryptionType::Deterministic,
                EncryptionTypeWire::Randomized => mssql_auth::EncryptionType::Randomized,
            },
        )
    }

    /// Decrypt `ciphertext` read from an encrypted column.
    ///
    /// # Arguments
    ///
    /// * `ciphertext` - Encrypted bytes
    /// * `cek_entry` - CEK table entry that applies to the column
    ///
    /// # Errors
    ///
    /// Propagates failures from [`get_encryptor`](Self::get_encryptor) and
    /// from the encryptor itself.
    pub async fn decrypt_value(
        &self,
        ciphertext: &[u8],
        cek_entry: &CekTableEntry,
    ) -> Result<Vec<u8>, EncryptionError> {
        self.get_encryptor(cek_entry).await?.decrypt(ciphertext)
    }

    /// Empty the CEK cache.
    ///
    /// Intended to be called when keys may have been rotated.
    pub fn clear_cache(&self) {
        self.cek_cache.clear();
    }

    /// Report whether a provider with the given name is registered.
    pub fn has_provider(&self, name: &str) -> bool {
        self.providers.contains_key(name)
    }
}
258
/// Debug output lists provider names and the cache size only; the providers
/// and cached key material themselves are not printed.
#[cfg(feature = "always-encrypted")]
impl std::fmt::Debug for EncryptionContext {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let provider_names: Vec<_> = self.providers.keys().collect();
        let cache_entries = self.cek_cache.len();

        let mut out = f.debug_struct("EncryptionContext");
        out.field("providers", &provider_names);
        out.field("cache_entries", &cache_entries);
        out.field("cache_enabled", &self.cache_enabled);
        out.finish()
    }
}
269
270/// Column encryption metadata for a result set.
271///
272/// This combines the CEK table with per-column crypto metadata,
273/// providing all information needed to decrypt result columns.
274#[derive(Debug, Clone)]
275pub struct ResultSetEncryptionInfo {
276    /// CEK table for this result set.
277    pub cek_table: CekTable,
278    /// Crypto metadata for each column (index matches column ordinal).
279    pub column_crypto: Vec<Option<CryptoMetadata>>,
280}
281
282impl ResultSetEncryptionInfo {
283    /// Create encryption info for a result set.
284    pub fn new(cek_table: CekTable, column_count: usize) -> Self {
285        Self {
286            cek_table,
287            column_crypto: vec![None; column_count],
288        }
289    }
290
291    /// Set crypto metadata for a column.
292    pub fn set_column_crypto(&mut self, ordinal: usize, metadata: CryptoMetadata) {
293        if ordinal < self.column_crypto.len() {
294            self.column_crypto[ordinal] = Some(metadata);
295        }
296    }
297
298    /// Get the CEK entry for a column.
299    pub fn get_cek_for_column(&self, ordinal: usize) -> Option<&CekTableEntry> {
300        let crypto = self.column_crypto.get(ordinal)?.as_ref()?;
301        self.cek_table.get(crypto.cek_table_ordinal)
302    }
303
304    /// Check if a column is encrypted.
305    pub fn is_column_encrypted(&self, ordinal: usize) -> bool {
306        self.column_crypto
307            .get(ordinal)
308            .map(|c| c.is_some())
309            .unwrap_or(false)
310    }
311
312    /// Get the encryption type for a column.
313    pub fn get_encryption_type(&self, ordinal: usize) -> Option<EncryptionTypeWire> {
314        self.column_crypto
315            .get(ordinal)?
316            .as_ref()
317            .map(|c| c.encryption_type)
318    }
319}
320
321/// Parameter encryption metadata for a query.
322///
323/// This is returned by `sp_describe_parameter_encryption` and describes
324/// how each parameter should be encrypted.
325#[derive(Debug, Clone)]
326pub struct ParameterEncryptionInfo {
327    /// CEK table for parameters.
328    pub cek_table: CekTable,
329    /// Mapping from parameter name to crypto metadata.
330    pub parameters: HashMap<String, ParameterCryptoInfo>,
331}
332
333impl ParameterEncryptionInfo {
334    /// Create empty parameter encryption info.
335    pub fn new() -> Self {
336        Self {
337            cek_table: CekTable::new(),
338            parameters: HashMap::new(),
339        }
340    }
341
342    /// Add encryption info for a parameter.
343    pub fn add_parameter(&mut self, name: String, info: ParameterCryptoInfo) {
344        self.parameters.insert(name, info);
345    }
346
347    /// Get encryption info for a parameter.
348    pub fn get_parameter(&self, name: &str) -> Option<&ParameterCryptoInfo> {
349        self.parameters.get(name)
350    }
351
352    /// Check if a parameter needs encryption.
353    pub fn needs_encryption(&self, name: &str) -> bool {
354        self.parameters.contains_key(name)
355    }
356}
357
358impl Default for ParameterEncryptionInfo {
359    fn default() -> Self {
360        Self::new()
361    }
362}
363
364/// Encryption metadata for a single parameter.
365#[derive(Debug, Clone)]
366pub struct ParameterCryptoInfo {
367    /// Index into the CEK table.
368    pub cek_ordinal: u16,
369    /// Encryption type (deterministic or randomized).
370    pub encryption_type: EncryptionTypeWire,
371    /// Algorithm ID.
372    pub algorithm_id: u8,
373    /// Target column ordinal in the table (for type information).
374    pub column_ordinal: u16,
375    /// Target column database ID.
376    pub database_id: u32,
377}
378
379impl ParameterCryptoInfo {
380    /// Create new parameter crypto info.
381    pub fn new(
382        cek_ordinal: u16,
383        encryption_type: EncryptionTypeWire,
384        algorithm_id: u8,
385        column_ordinal: u16,
386        database_id: u32,
387    ) -> Self {
388        Self {
389            cek_ordinal,
390            encryption_type,
391            algorithm_id,
392            column_ordinal,
393            database_id,
394        }
395    }
396}
397
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// A freshly created config is enabled with caching on, yet not ready
    /// because no provider has been registered.
    #[test]
    fn test_encryption_config_defaults() {
        let cfg = EncryptionConfig::new();

        assert!(cfg.enabled);
        assert!(cfg.cache_ceks);
        assert!(!cfg.is_ready());
    }

    /// Only the column that received crypto metadata reports as encrypted.
    #[test]
    fn test_result_set_encryption_info() {
        let mut info = ResultSetEncryptionInfo::new(CekTable::new(), 3);

        // Nothing is encrypted before any metadata is attached.
        for ordinal in 0..3 {
            assert!(!info.is_column_encrypted(ordinal));
        }

        info.set_column_crypto(
            1,
            CryptoMetadata {
                cek_table_ordinal: 0,
                algorithm_id: 2,
                encryption_type: EncryptionTypeWire::Deterministic,
                normalization_version: 1,
            },
        );

        // Exactly column 1 is encrypted afterwards.
        for ordinal in 0..3 {
            assert_eq!(info.is_column_encrypted(ordinal), ordinal == 1);
        }

        assert_eq!(
            info.get_encryption_type(1),
            Some(EncryptionTypeWire::Deterministic)
        );
    }

    /// Parameters need encryption only once metadata has been added for them.
    #[test]
    fn test_parameter_encryption_info() {
        let mut info = ParameterEncryptionInfo::new();
        assert!(!info.needs_encryption("@p1"));

        info.add_parameter(
            "@p1".to_string(),
            ParameterCryptoInfo::new(0, EncryptionTypeWire::Randomized, 2, 1, 1),
        );

        assert!(info.needs_encryption("@p1"));
        assert!(!info.needs_encryption("@p2"));

        let stored = info
            .get_parameter("@p1")
            .expect("metadata for @p1 was just added");
        assert_eq!(stored.encryption_type, EncryptionTypeWire::Randomized);
    }
}