Skip to main content

mssql_client/
encryption.rs

//! Always Encrypted client-side encryption and decryption.
//!
//! This module provides the infrastructure for SQL Server's Always Encrypted feature,
//! which enables client-side encryption of sensitive database columns.
//!
//! ## Architecture
//!
//! Always Encrypted uses a two-tier key hierarchy:
//!
//! ```text
//! Column Master Key (CMK) - External (KeyVault, CertStore, HSM)
//!         │
//!         ▼ RSA-OAEP unwrap
//! Column Encryption Key (CEK) - Stored encrypted in database
//!         │
//!         ▼ AEAD_AES_256_CBC_HMAC_SHA256
//! Encrypted Column Data
//! ```
//!
//! ## Usage
//!
//! ```rust,ignore
//! use mssql_client::{Client, Config, EncryptionConfig};
//! use mssql_auth::InMemoryKeyStore;
//!
//! // Create encryption configuration
//! let mut key_store = InMemoryKeyStore::new();
//! key_store.add_key("MyKey", &pem)?;
//!
//! let encryption_config = EncryptionConfig::new()
//!     .with_provider(key_store);
//!
//! // Connect with encryption enabled
//! let config = Config::from_connection_string(conn_str)?
//!     .with_encryption(encryption_config);
//!
//! let client = Client::connect(config).await?;
//! ```
//!
//! ## Security Model
//!
//! - **Client-only decryption**: SQL Server never sees the plaintext of encrypted columns
//! - **DBA protection**: Even database administrators cannot read encrypted data
//! - **Key separation**: The CMK stays in a secure key store and is never transmitted

47use std::collections::HashMap;
48
49use mssql_auth::KeyStoreProvider;
50use tds_protocol::crypto::{CekTable, CekTableEntry, CryptoMetadata, EncryptionTypeWire};
51
52#[cfg(feature = "always-encrypted")]
53use mssql_auth::{AeadEncryptor, CekCache, CekCacheKey, EncryptionError};
54#[cfg(feature = "always-encrypted")]
55use std::sync::Arc;
56
57/// Configuration for Always Encrypted feature.
58#[derive(Default)]
59pub struct EncryptionConfig {
60    /// Whether encryption is enabled.
61    pub enabled: bool,
62    /// Registered key store providers.
63    providers: Vec<Box<dyn KeyStoreProvider>>,
64    /// Whether to cache decrypted CEKs for performance.
65    pub cache_ceks: bool,
66}
67
68impl EncryptionConfig {
69    /// Create a new encryption configuration (disabled by default).
70    #[must_use]
71    pub fn new() -> Self {
72        Self {
73            enabled: true,
74            providers: Vec::new(),
75            cache_ceks: true,
76        }
77    }
78
79    /// Register a key store provider.
80    pub fn register_provider(&mut self, provider: impl KeyStoreProvider + 'static) {
81        self.providers.push(Box::new(provider));
82    }
83
84    /// Builder method to add a key store provider.
85    #[must_use]
86    pub fn with_provider(mut self, provider: impl KeyStoreProvider + 'static) -> Self {
87        self.register_provider(provider);
88        self
89    }
90
91    /// Enable or disable CEK caching.
92    #[must_use]
93    pub fn with_cek_caching(mut self, enabled: bool) -> Self {
94        self.cache_ceks = enabled;
95        self
96    }
97
98    /// Get a provider by name.
99    pub fn get_provider(&self, name: &str) -> Option<&dyn KeyStoreProvider> {
100        self.providers
101            .iter()
102            .find(|p| p.provider_name() == name)
103            .map(|p| p.as_ref())
104    }
105
106    /// Check if encryption is ready (enabled and has providers).
107    #[must_use]
108    pub fn is_ready(&self) -> bool {
109        self.enabled && !self.providers.is_empty()
110    }
111}
112
113impl std::fmt::Debug for EncryptionConfig {
114    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115        f.debug_struct("EncryptionConfig")
116            .field("enabled", &self.enabled)
117            .field("provider_count", &self.providers.len())
118            .field("cache_ceks", &self.cache_ceks)
119            .finish()
120    }
121}
122
/// Runtime context for encryption operations.
///
/// This is the active encryption state for a connected client: it resolves
/// column encryption keys (CEKs) through the configured key store providers
/// and, when caching is enabled, keeps the resulting encryptors in a
/// [`CekCache`].
///
/// The context holds an `Arc<EncryptionConfig>` so providers remain accessible
/// across connection retries/redirects where the `Config` (and its inner
/// encryption config Arc) gets cloned multiple times.
#[cfg(feature = "always-encrypted")]
pub struct EncryptionContext {
    /// Shared handle on the user-supplied configuration. Providers are looked
    /// up by name through this reference, so an arbitrary number of `Arc`
    /// clones do not lose access to them.
    config: Arc<EncryptionConfig>,
    /// Cache of encryptors keyed by (database id, CEK id, CEK version).
    cek_cache: CekCache,
    /// Whether caching is enabled (copied from `EncryptionConfig::cache_ceks`
    /// when the context is created).
    cache_enabled: bool,
}

#[cfg(feature = "always-encrypted")]
impl EncryptionContext {
    /// Create a new encryption context from an Arc-wrapped configuration.
    ///
    /// The Arc is retained by the context so provider lookups continue to
    /// work for the lifetime of the client — regardless of how many times
    /// the outer `Config` has been cloned for retry/redirect handling.
    pub fn from_arc(config: Arc<EncryptionConfig>) -> Self {
        let cache_enabled = config.cache_ceks;
        Self {
            config,
            cek_cache: CekCache::new(),
            cache_enabled,
        }
    }

    /// Create a new encryption context that takes ownership of `config`.
    pub fn new(config: EncryptionConfig) -> Self {
        Self::from_arc(Arc::new(config))
    }

    /// Get the encryptor for a column's CEK, decrypting the CEK if needed.
    ///
    /// 1. If caching is enabled, return the cached encryptor for this
    ///    `(database_id, cek_id, cek_version)` when present.
    /// 2. Otherwise decrypt the CEK's primary value with the key store
    ///    provider named in that value.
    /// 3. Build an encryptor from the decrypted CEK, caching it when enabled.
    ///
    /// # Errors
    ///
    /// - [`EncryptionError::CekDecryptionFailed`] if the entry has no primary value.
    /// - [`EncryptionError::KeyStoreNotFound`] if no registered provider matches
    ///   the value's key store provider name.
    /// - Any error from the provider's `decrypt_cek`, or from constructing or
    ///   caching the encryptor.
    pub async fn get_encryptor(
        &self,
        cek_entry: &CekTableEntry,
    ) -> Result<Arc<AeadEncryptor>, EncryptionError> {
        let cache_key = CekCacheKey::new(
            cek_entry.database_id,
            cek_entry.cek_id,
            cek_entry.cek_version,
        );

        // Check cache first
        if self.cache_enabled {
            if let Some(encryptor) = self.cek_cache.get(&cache_key) {
                return Ok(encryptor);
            }
        }

        // Get the primary CEK value
        let cek_value = cek_entry
            .primary_value()
            .ok_or_else(|| EncryptionError::CekDecryptionFailed("No CEK value available".into()))?;

        // Find the appropriate key store provider via the shared config
        let provider = self
            .config
            .get_provider(&cek_value.key_store_provider_name)
            .ok_or_else(|| {
                EncryptionError::KeyStoreNotFound(cek_value.key_store_provider_name.clone())
            })?;

        // Decrypt the CEK
        let decrypted_cek = provider
            .decrypt_cek(
                &cek_value.cmk_path,
                &cek_value.encryption_algorithm,
                &cek_value.encrypted_value,
            )
            .await?;

        // Create encryptor and cache it
        if self.cache_enabled {
            self.cek_cache.insert(cache_key, decrypted_cek)
        } else {
            // Create encryptor without caching
            Ok(Arc::new(AeadEncryptor::new(&decrypted_cek)?))
        }
    }

    /// Encrypt a value for a column.
    ///
    /// # Arguments
    ///
    /// * `plaintext` - The plaintext value to encrypt
    /// * `cek_entry` - The CEK table entry for this column
    /// * `encryption_type` - Deterministic or randomized encryption
    ///
    /// # Errors
    ///
    /// [`EncryptionError::UnsupportedOperation`] if `encryption_type` is neither
    /// deterministic nor randomized; otherwise any error from
    /// [`get_encryptor`](Self::get_encryptor) or from the encryption itself.
    pub async fn encrypt_value(
        &self,
        plaintext: &[u8],
        cek_entry: &CekTableEntry,
        encryption_type: EncryptionTypeWire,
    ) -> Result<Vec<u8>, EncryptionError> {
        // Validate the requested encryption type before resolving the CEK:
        // resolving may call out to an external key store provider, which is
        // wasted work when the request cannot be honoured anyway.
        let enc_type = match encryption_type {
            EncryptionTypeWire::Deterministic => mssql_auth::EncryptionType::Deterministic,
            EncryptionTypeWire::Randomized => mssql_auth::EncryptionType::Randomized,
            _ => {
                return Err(EncryptionError::UnsupportedOperation(format!(
                    "unsupported encryption type: {encryption_type:?}"
                )));
            }
        };

        let encryptor = self.get_encryptor(cek_entry).await?;
        encryptor.encrypt(plaintext, enc_type)
    }

    /// Decrypt a value from an encrypted column.
    ///
    /// # Arguments
    ///
    /// * `ciphertext` - The encrypted value
    /// * `cek_entry` - The CEK table entry for this column
    ///
    /// # Errors
    ///
    /// Any error from [`get_encryptor`](Self::get_encryptor) or from the
    /// decryption itself.
    pub async fn decrypt_value(
        &self,
        ciphertext: &[u8],
        cek_entry: &CekTableEntry,
    ) -> Result<Vec<u8>, EncryptionError> {
        let encryptor = self.get_encryptor(cek_entry).await?;
        encryptor.decrypt(ciphertext)
    }

    /// Clear the CEK cache.
    ///
    /// Call this when keys may have been rotated.
    pub fn clear_cache(&self) {
        self.cek_cache.clear();
    }

    /// Check if a provider with the given name is registered in the config.
    pub fn has_provider(&self, name: &str) -> bool {
        self.config.get_provider(name).is_some()
    }
}

#[cfg(feature = "always-encrypted")]
impl std::fmt::Debug for EncryptionContext {
    /// Formats counts and flags only; no key material is printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("EncryptionContext")
            .field("provider_count", &self.config.providers.len())
            .field("cache_entries", &self.cek_cache.len())
            .field("cache_enabled", &self.cache_enabled)
            .finish()
    }
}
284
285/// Column encryption metadata for a result set.
286///
287/// This combines the CEK table with per-column crypto metadata,
288/// providing all information needed to decrypt result columns.
289#[derive(Debug, Clone)]
290pub struct ResultSetEncryptionInfo {
291    /// CEK table for this result set.
292    pub cek_table: CekTable,
293    /// Crypto metadata for each column (index matches column ordinal).
294    pub column_crypto: Vec<Option<CryptoMetadata>>,
295}
296
297impl ResultSetEncryptionInfo {
298    /// Create encryption info for a result set.
299    pub fn new(cek_table: CekTable, column_count: usize) -> Self {
300        Self {
301            cek_table,
302            column_crypto: vec![None; column_count],
303        }
304    }
305
306    /// Set crypto metadata for a column.
307    pub fn set_column_crypto(&mut self, ordinal: usize, metadata: CryptoMetadata) {
308        if ordinal < self.column_crypto.len() {
309            self.column_crypto[ordinal] = Some(metadata);
310        }
311    }
312
313    /// Get the CEK entry for a column.
314    pub fn get_cek_for_column(&self, ordinal: usize) -> Option<&CekTableEntry> {
315        let crypto = self.column_crypto.get(ordinal)?.as_ref()?;
316        self.cek_table.get(crypto.cek_table_ordinal)
317    }
318
319    /// Check if a column is encrypted.
320    pub fn is_column_encrypted(&self, ordinal: usize) -> bool {
321        self.column_crypto
322            .get(ordinal)
323            .map(|c| c.is_some())
324            .unwrap_or(false)
325    }
326
327    /// Get the encryption type for a column.
328    pub fn get_encryption_type(&self, ordinal: usize) -> Option<EncryptionTypeWire> {
329        self.column_crypto
330            .get(ordinal)?
331            .as_ref()
332            .map(|c| c.encryption_type)
333    }
334}
335
336/// Parameter encryption metadata for a query.
337///
338/// This is returned by `sp_describe_parameter_encryption` and describes
339/// how each parameter should be encrypted.
340#[derive(Debug, Clone)]
341pub struct ParameterEncryptionInfo {
342    /// CEK table for parameters.
343    pub cek_table: CekTable,
344    /// Mapping from parameter name to crypto metadata.
345    pub parameters: HashMap<String, ParameterCryptoInfo>,
346}
347
348impl ParameterEncryptionInfo {
349    /// Create empty parameter encryption info.
350    pub fn new() -> Self {
351        Self {
352            cek_table: CekTable::new(),
353            parameters: HashMap::new(),
354        }
355    }
356
357    /// Add encryption info for a parameter.
358    pub fn add_parameter(&mut self, name: String, info: ParameterCryptoInfo) {
359        self.parameters.insert(name, info);
360    }
361
362    /// Get encryption info for a parameter.
363    pub fn get_parameter(&self, name: &str) -> Option<&ParameterCryptoInfo> {
364        self.parameters.get(name)
365    }
366
367    /// Check if a parameter needs encryption.
368    pub fn needs_encryption(&self, name: &str) -> bool {
369        self.parameters.contains_key(name)
370    }
371}
372
373impl Default for ParameterEncryptionInfo {
374    fn default() -> Self {
375        Self::new()
376    }
377}
378
379/// Encryption metadata for a single parameter.
380#[derive(Debug, Clone)]
381pub struct ParameterCryptoInfo {
382    /// Index into the CEK table.
383    pub cek_ordinal: u16,
384    /// Encryption type (deterministic or randomized).
385    pub encryption_type: EncryptionTypeWire,
386    /// Algorithm ID.
387    pub algorithm_id: u8,
388    /// Target column ordinal in the table (for type information).
389    pub column_ordinal: u16,
390    /// Target column database ID.
391    pub database_id: u32,
392}
393
394impl ParameterCryptoInfo {
395    /// Create new parameter crypto info.
396    pub fn new(
397        cek_ordinal: u16,
398        encryption_type: EncryptionTypeWire,
399        algorithm_id: u8,
400        column_ordinal: u16,
401        database_id: u32,
402    ) -> Self {
403        Self {
404            cek_ordinal,
405            encryption_type,
406            algorithm_id,
407            column_ordinal,
408            database_id,
409        }
410    }
411}
412
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// `new()` enables encryption and caching, but is not ready without providers.
    #[test]
    fn test_encryption_config_defaults() {
        let config = EncryptionConfig::new();
        assert!(config.enabled);
        assert!(config.cache_ceks);
        assert!(!config.is_ready()); // No providers
    }

    /// Columns start unencrypted; setting metadata flags exactly that column.
    #[test]
    fn test_result_set_encryption_info() {
        let mut info = ResultSetEncryptionInfo::new(CekTable::new(), 3);

        assert!((0..3).all(|ordinal| !info.is_column_encrypted(ordinal)));

        let metadata = CryptoMetadata {
            cek_table_ordinal: 0,
            base_user_type: 0,
            base_col_type: 0x26,
            base_type_info: tds_protocol::token::TypeInfo::default(),
            algorithm_id: 2,
            encryption_type: EncryptionTypeWire::Deterministic,
            normalization_version: 1,
        };
        info.set_column_crypto(1, metadata);

        let encrypted: Vec<bool> = (0..3)
            .map(|ordinal| info.is_column_encrypted(ordinal))
            .collect();
        assert_eq!(encrypted, [false, true, false]);

        assert_eq!(
            info.get_encryption_type(1),
            Some(EncryptionTypeWire::Deterministic)
        );
    }

    /// Only parameters that were added report that they need encryption.
    #[test]
    fn test_parameter_encryption_info() {
        let mut info = ParameterEncryptionInfo::new();
        assert!(!info.needs_encryption("@p1"));

        info.add_parameter(
            String::from("@p1"),
            ParameterCryptoInfo::new(0, EncryptionTypeWire::Randomized, 2, 1, 1),
        );

        assert!(info.needs_encryption("@p1"));
        assert!(!info.needs_encryption("@p2"));

        let param = info.get_parameter("@p1").unwrap();
        assert_eq!(param.encryption_type, EncryptionTypeWire::Randomized);
    }
}