Skip to main content

mssql_client/
encryption.rs

1//! Always Encrypted client-side encryption and decryption.
2//!
3//! This module provides the infrastructure for SQL Server's Always Encrypted feature,
4//! which enables client-side encryption of sensitive database columns.
5//!
6//! ## Architecture
7//!
8//! Always Encrypted uses a two-tier key hierarchy:
9//!
10//! ```text
11//! Column Master Key (CMK) - External (KeyVault, CertStore, HSM)
12//!         │
13//!         ▼ RSA-OAEP unwrap
14//! Column Encryption Key (CEK) - Stored encrypted in database
15//!         │
16//!         ▼ AEAD_AES_256_CBC_HMAC_SHA256
17//! Encrypted Column Data
18//! ```
19//!
20//! ## Usage
21//!
22//! ```rust,ignore
//! use mssql_client::{Client, Config, EncryptionConfig};
24//! use mssql_auth::InMemoryKeyStore;
25//!
26//! // Create encryption configuration
27//! let mut key_store = InMemoryKeyStore::new();
28//! key_store.add_key("MyKey", &pem)?;
29//!
//! let encryption_config = EncryptionConfig::new()
//!     .with_provider(key_store);
33//!
34//! // Connect with encryption enabled
35//! let config = Config::from_connection_string(conn_str)?
36//!     .with_encryption(encryption_config);
37//!
38//! let client = Client::connect(config).await?;
39//! ```
40//!
41//! ## Security Model
42//!
43//! - **Client-only decryption**: SQL Server never sees plaintext data
44//! - **DBA protection**: Even database administrators cannot read encrypted data
45//! - **Key separation**: CMK stays in secure key store, never transmitted
46
47use std::collections::HashMap;
48
49use mssql_auth::KeyStoreProvider;
50use tds_protocol::crypto::{CekTable, CekTableEntry, CryptoMetadata, EncryptionTypeWire};
51
52#[cfg(feature = "always-encrypted")]
53use mssql_auth::{AeadEncryptor, CekCache, CekCacheKey, EncryptionError};
54#[cfg(feature = "always-encrypted")]
55use std::sync::Arc;
56
57/// Configuration for Always Encrypted feature.
58#[derive(Default)]
59pub struct EncryptionConfig {
60    /// Whether encryption is enabled.
61    pub enabled: bool,
62    /// Registered key store providers.
63    providers: Vec<Box<dyn KeyStoreProvider>>,
64    /// Whether to cache decrypted CEKs for performance.
65    pub cache_ceks: bool,
66}
67
68impl EncryptionConfig {
69    /// Create a new encryption configuration (disabled by default).
70    #[must_use]
71    pub fn new() -> Self {
72        Self {
73            enabled: true,
74            providers: Vec::new(),
75            cache_ceks: true,
76        }
77    }
78
79    /// Register a key store provider.
80    pub fn register_provider(&mut self, provider: impl KeyStoreProvider + 'static) {
81        self.providers.push(Box::new(provider));
82    }
83
84    /// Builder method to add a key store provider.
85    #[must_use]
86    pub fn with_provider(mut self, provider: impl KeyStoreProvider + 'static) -> Self {
87        self.register_provider(provider);
88        self
89    }
90
91    /// Enable or disable CEK caching.
92    #[must_use]
93    pub fn with_cek_caching(mut self, enabled: bool) -> Self {
94        self.cache_ceks = enabled;
95        self
96    }
97
98    /// Get a provider by name.
99    pub fn get_provider(&self, name: &str) -> Option<&dyn KeyStoreProvider> {
100        self.providers
101            .iter()
102            .find(|p| p.provider_name() == name)
103            .map(|p| p.as_ref())
104    }
105
106    /// Check if encryption is ready (enabled and has providers).
107    #[must_use]
108    pub fn is_ready(&self) -> bool {
109        self.enabled && !self.providers.is_empty()
110    }
111}
112
113impl std::fmt::Debug for EncryptionConfig {
114    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115        f.debug_struct("EncryptionConfig")
116            .field("enabled", &self.enabled)
117            .field("provider_count", &self.providers.len())
118            .field("cache_ceks", &self.cache_ceks)
119            .finish()
120    }
121}
122
/// Live Always Encrypted state for a connected client.
///
/// Holds the key store providers, indexed by provider name, together with a
/// cache of encryptors built from decrypted column encryption keys.
#[cfg(feature = "always-encrypted")]
pub struct EncryptionContext {
    /// Whether `cek_cache` is consulted and populated.
    cache_enabled: bool,
    /// Encryptors built from previously decrypted CEKs.
    cek_cache: CekCache,
    /// Key store providers, keyed by `provider_name()`.
    providers: HashMap<String, Box<dyn KeyStoreProvider>>,
}
136
#[cfg(feature = "always-encrypted")]
impl EncryptionContext {
    /// Build a runtime context from an [`EncryptionConfig`].
    ///
    /// Providers are indexed by `provider_name()`. If two share a name, the
    /// one registered later wins. The CEK cache starts out empty.
    pub fn new(config: EncryptionConfig) -> Self {
        let mut providers = HashMap::new();
        for provider in config.providers {
            providers.insert(provider.provider_name().to_string(), provider);
        }

        Self {
            cache_enabled: config.cache_ceks,
            cek_cache: CekCache::new(),
            providers,
        }
    }

    /// Return an encryptor for the column encryption key in `cek_entry`.
    ///
    /// When caching is enabled, a cached encryptor for the entry's
    /// (database id, CEK id, CEK version) is returned if present. Otherwise
    /// the following steps run:
    ///
    /// 1. The entry's primary encrypted CEK value is unwrapped with the
    ///    matching key store provider.
    /// 2. An encryptor is built from the unwrapped key.
    /// 3. The encryptor is stored in the cache when caching is on.
    ///
    /// # Errors
    ///
    /// * `CekDecryptionFailed` if the entry has no CEK value.
    /// * `KeyStoreNotFound` if no provider with the required name is
    ///   registered.
    /// * Any error raised while unwrapping the key or building the encryptor.
    pub async fn get_encryptor(
        &self,
        cek_entry: &CekTableEntry,
    ) -> Result<Arc<AeadEncryptor>, EncryptionError> {
        let key = CekCacheKey::new(cek_entry.database_id, cek_entry.cek_id, cek_entry.cek_version);

        if self.cache_enabled {
            if let Some(cached) = self.cek_cache.get(&key) {
                return Ok(cached);
            }
        }

        let value = match cek_entry.primary_value() {
            Some(value) => value,
            None => {
                return Err(EncryptionError::CekDecryptionFailed(
                    "No CEK value available".into(),
                ))
            }
        };

        let provider_name = &value.key_store_provider_name;
        let provider = match self.providers.get(provider_name) {
            Some(provider) => provider,
            None => return Err(EncryptionError::KeyStoreNotFound(provider_name.clone())),
        };

        let cek = provider
            .decrypt_cek(&value.cmk_path, &value.encryption_algorithm, &value.encrypted_value)
            .await?;

        if !self.cache_enabled {
            // Caching is off: build a one-off encryptor for this call only.
            return Ok(Arc::new(AeadEncryptor::new(&cek)?));
        }
        self.cek_cache.insert(key, cek)
    }

    /// Encrypt `plaintext` for the column described by `cek_entry`.
    ///
    /// `encryption_type` selects deterministic or randomized encryption. Any
    /// other wire value is rejected with `UnsupportedOperation`, but only
    /// after the encryptor has been obtained.
    pub async fn encrypt_value(
        &self,
        plaintext: &[u8],
        cek_entry: &CekTableEntry,
        encryption_type: EncryptionTypeWire,
    ) -> Result<Vec<u8>, EncryptionError> {
        let encryptor = self.get_encryptor(cek_entry).await?;

        let mode = match encryption_type {
            EncryptionTypeWire::Deterministic => mssql_auth::EncryptionType::Deterministic,
            EncryptionTypeWire::Randomized => mssql_auth::EncryptionType::Randomized,
            _ => {
                return Err(EncryptionError::UnsupportedOperation(format!(
                    "unsupported encryption type: {encryption_type:?}"
                )));
            }
        };

        encryptor.encrypt(plaintext, mode)
    }

    /// Decrypt `ciphertext` read from the column described by `cek_entry`.
    pub async fn decrypt_value(
        &self,
        ciphertext: &[u8],
        cek_entry: &CekTableEntry,
    ) -> Result<Vec<u8>, EncryptionError> {
        self.get_encryptor(cek_entry).await?.decrypt(ciphertext)
    }

    /// Drop every cached encryptor, for example after a key rotation.
    pub fn clear_cache(&self) {
        self.cek_cache.clear();
    }

    /// Returns `true` if a provider with this name is registered.
    pub fn has_provider(&self, name: &str) -> bool {
        self.providers.contains_key(name)
    }
}
263
#[cfg(feature = "always-encrypted")]
impl std::fmt::Debug for EncryptionContext {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Print provider names and the cache size, not the providers or keys.
        let provider_names: Vec<&String> = self.providers.keys().collect();
        let cached = self.cek_cache.len();
        f.debug_struct("EncryptionContext")
            .field("providers", &provider_names)
            .field("cache_entries", &cached)
            .field("cache_enabled", &self.cache_enabled)
            .finish()
    }
}
274
275/// Column encryption metadata for a result set.
276///
277/// This combines the CEK table with per-column crypto metadata,
278/// providing all information needed to decrypt result columns.
279#[derive(Debug, Clone)]
280pub struct ResultSetEncryptionInfo {
281    /// CEK table for this result set.
282    pub cek_table: CekTable,
283    /// Crypto metadata for each column (index matches column ordinal).
284    pub column_crypto: Vec<Option<CryptoMetadata>>,
285}
286
287impl ResultSetEncryptionInfo {
288    /// Create encryption info for a result set.
289    pub fn new(cek_table: CekTable, column_count: usize) -> Self {
290        Self {
291            cek_table,
292            column_crypto: vec![None; column_count],
293        }
294    }
295
296    /// Set crypto metadata for a column.
297    pub fn set_column_crypto(&mut self, ordinal: usize, metadata: CryptoMetadata) {
298        if ordinal < self.column_crypto.len() {
299            self.column_crypto[ordinal] = Some(metadata);
300        }
301    }
302
303    /// Get the CEK entry for a column.
304    pub fn get_cek_for_column(&self, ordinal: usize) -> Option<&CekTableEntry> {
305        let crypto = self.column_crypto.get(ordinal)?.as_ref()?;
306        self.cek_table.get(crypto.cek_table_ordinal)
307    }
308
309    /// Check if a column is encrypted.
310    pub fn is_column_encrypted(&self, ordinal: usize) -> bool {
311        self.column_crypto
312            .get(ordinal)
313            .map(|c| c.is_some())
314            .unwrap_or(false)
315    }
316
317    /// Get the encryption type for a column.
318    pub fn get_encryption_type(&self, ordinal: usize) -> Option<EncryptionTypeWire> {
319        self.column_crypto
320            .get(ordinal)?
321            .as_ref()
322            .map(|c| c.encryption_type)
323    }
324}
325
326/// Parameter encryption metadata for a query.
327///
328/// This is returned by `sp_describe_parameter_encryption` and describes
329/// how each parameter should be encrypted.
330#[derive(Debug, Clone)]
331pub struct ParameterEncryptionInfo {
332    /// CEK table for parameters.
333    pub cek_table: CekTable,
334    /// Mapping from parameter name to crypto metadata.
335    pub parameters: HashMap<String, ParameterCryptoInfo>,
336}
337
338impl ParameterEncryptionInfo {
339    /// Create empty parameter encryption info.
340    pub fn new() -> Self {
341        Self {
342            cek_table: CekTable::new(),
343            parameters: HashMap::new(),
344        }
345    }
346
347    /// Add encryption info for a parameter.
348    pub fn add_parameter(&mut self, name: String, info: ParameterCryptoInfo) {
349        self.parameters.insert(name, info);
350    }
351
352    /// Get encryption info for a parameter.
353    pub fn get_parameter(&self, name: &str) -> Option<&ParameterCryptoInfo> {
354        self.parameters.get(name)
355    }
356
357    /// Check if a parameter needs encryption.
358    pub fn needs_encryption(&self, name: &str) -> bool {
359        self.parameters.contains_key(name)
360    }
361}
362
363impl Default for ParameterEncryptionInfo {
364    fn default() -> Self {
365        Self::new()
366    }
367}
368
369/// Encryption metadata for a single parameter.
370#[derive(Debug, Clone)]
371pub struct ParameterCryptoInfo {
372    /// Index into the CEK table.
373    pub cek_ordinal: u16,
374    /// Encryption type (deterministic or randomized).
375    pub encryption_type: EncryptionTypeWire,
376    /// Algorithm ID.
377    pub algorithm_id: u8,
378    /// Target column ordinal in the table (for type information).
379    pub column_ordinal: u16,
380    /// Target column database ID.
381    pub database_id: u32,
382}
383
384impl ParameterCryptoInfo {
385    /// Create new parameter crypto info.
386    pub fn new(
387        cek_ordinal: u16,
388        encryption_type: EncryptionTypeWire,
389        algorithm_id: u8,
390        column_ordinal: u16,
391        database_id: u32,
392    ) -> Self {
393        Self {
394            cek_ordinal,
395            encryption_type,
396            algorithm_id,
397            column_ordinal,
398            database_id,
399        }
400    }
401}
402
#[cfg(test)]
// Panicking helpers are acceptable inside unit tests.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    #[test]
    fn test_encryption_config_defaults() {
        let config = EncryptionConfig::new();
        assert!(config.enabled, "new() should enable encryption");
        assert!(config.cache_ceks, "new() should enable CEK caching");
        // Enabled, but not ready: no key store provider is registered yet.
        assert!(!config.is_ready());
    }

    #[test]
    fn test_result_set_encryption_info() {
        let mut info = ResultSetEncryptionInfo::new(CekTable::new(), 3);
        assert!((0..3).all(|ordinal| !info.is_column_encrypted(ordinal)));

        info.set_column_crypto(
            1,
            CryptoMetadata {
                cek_table_ordinal: 0,
                algorithm_id: 2,
                encryption_type: EncryptionTypeWire::Deterministic,
                normalization_version: 1,
            },
        );

        let flags: Vec<bool> = (0..3)
            .map(|ordinal| info.is_column_encrypted(ordinal))
            .collect();
        assert_eq!(flags, vec![false, true, false]);
        assert_eq!(
            info.get_encryption_type(1),
            Some(EncryptionTypeWire::Deterministic)
        );
    }

    #[test]
    fn test_parameter_encryption_info() {
        let mut info = ParameterEncryptionInfo::new();
        assert!(!info.needs_encryption("@p1"));

        info.add_parameter(
            "@p1".to_string(),
            ParameterCryptoInfo::new(0, EncryptionTypeWire::Randomized, 2, 1, 1),
        );

        assert!(info.needs_encryption("@p1"));
        assert!(!info.needs_encryption("@p2"));

        let param = info.get_parameter("@p1").expect("@p1 was just added");
        assert_eq!(param.encryption_type, EncryptionTypeWire::Randomized);
    }
}