Skip to main content

mssql_client/
encryption.rs

1//! Always Encrypted client-side encryption and decryption.
2//!
3//! This module provides the infrastructure for SQL Server's Always Encrypted feature,
4//! which enables client-side encryption of sensitive database columns.
5//!
6//! ## Architecture
7//!
8//! Always Encrypted uses a two-tier key hierarchy:
9//!
10//! ```text
11//! Column Master Key (CMK) - External (KeyVault, CertStore, HSM)
12//!         │
13//!         ▼ RSA-OAEP unwrap
14//! Column Encryption Key (CEK) - Stored encrypted in database
15//!         │
16//!         ▼ AEAD_AES_256_CBC_HMAC_SHA256
17//! Encrypted Column Data
18//! ```
19//!
20//! ## Usage
21//!
22//! ```rust,no_run
23//! # async fn with_always_encrypted() -> Result<(), Box<dyn std::error::Error>> {
24//! # #[cfg(feature = "always-encrypted")]
25//! # {
26//! # let conn_str = "Server=localhost;Database=db;Encrypt=strict;Column Encryption Setting=Enabled";
27//! use mssql_client::{Client, Config, EncryptionConfig};
28//! use mssql_auth::InMemoryKeyStore;
29//!
30//! // Register a key-store provider, then attach it to the connection config.
31//! let key_store = InMemoryKeyStore::new();
32//! let encryption_config = EncryptionConfig::new().with_provider(key_store);
33//!
34//! let config = Config::from_connection_string(conn_str)?
35//!     .with_column_encryption(encryption_config);
36//!
37//! let _client = Client::connect(config).await?;
38//! # }
39//! # Ok(())
40//! # }
41//! ```
42//!
//! Equivalently, set `Column Encryption Setting=Enabled` in the connection
//! string. Key-store providers ship in `mssql-auth`: `InMemoryKeyStore`
//! (development and testing only), `AzureKeyVaultProvider` (`azure-identity`
//! feature), and `WindowsCertStoreProvider` (`sspi-auth`, Windows only).
//! Implement [`mssql_auth::KeyStoreProvider`] for custom key storage. Do **not**
//! substitute T-SQL `ENCRYPTBYKEY`: the server sees that plaintext, which
//! defeats the point of Always Encrypted.
49//!
50//! ## How decryption works
51//!
52//! 1. Always Encrypted support is negotiated in LOGIN7 (`FEATURE_EXT`).
53//! 2. `ColMetaData` carries [`CryptoMetadata`] and the [`CekTable`]; column
54//!    encryption keys are resolved asynchronously up front (calling the key-store
55//!    providers).
56//! 3. Each encrypted cell is decrypted during row parsing via
57//!    AEAD_AES_256_CBC_HMAC_SHA256, with the HMAC verified before decryption.
58//!
59//! Reads are transparent across `query`, `call_procedure`, the procedure
60//! builder, and multi-result queries. Parameter (write) encryption is wired
61//! into parameterized `query`/`execute` for the common scalar types — `int`,
62//! `tinyint`, `smallint`, `bigint`, `bit`, `real`, `float`, `nvarchar`,
63//! `varbinary`, `uniqueidentifier`, `date`, `money`, `smallmoney`, `decimal`
64//! (via `numeric(value, precision, scale)`), and typed `NULL` (via
65//! `null::<T>()`): with `Column Encryption Setting=Enabled` the
66//! driver describes the parameters (`sp_describe_parameter_encryption`),
67//! encrypts those bound to encrypted columns client-side, and sends them as
68//! encrypted RPC parameters (deterministic and randomized). The remaining
69//! temporal and fixed-width types are not yet supported and return an error
70//! rather than sending plaintext.
71//!
72//! ## Security Model
73//!
74//! - **Client-only decryption**: SQL Server never sees plaintext data
75//! - **DBA protection**: Even database administrators cannot read encrypted data
76//! - **Key separation**: CMK stays in secure key store, never transmitted
77
78use std::collections::HashMap;
79
80use mssql_auth::KeyStoreProvider;
81use tds_protocol::crypto::{CekTable, CekTableEntry, CryptoMetadata, EncryptionTypeWire};
82
83#[cfg(feature = "always-encrypted")]
84use mssql_auth::{AeadEncryptor, CekCache, CekCacheKey, EncryptionError};
85#[cfg(feature = "always-encrypted")]
86use mssql_types::SqlValue;
87#[cfg(feature = "always-encrypted")]
88use std::sync::Arc;
89
90#[cfg(feature = "always-encrypted")]
91use crate::{Error, row::Row, stream::ResultSet};
92#[cfg(feature = "always-encrypted")]
93use tds_protocol::crypto::CekValue;
94
95/// Configuration for Always Encrypted feature.
96#[derive(Default)]
97pub struct EncryptionConfig {
98    /// Whether encryption is enabled.
99    pub enabled: bool,
100    /// Registered key store providers.
101    providers: Vec<Box<dyn KeyStoreProvider>>,
102    /// Whether to cache decrypted CEKs for performance.
103    pub cache_ceks: bool,
104}
105
106impl EncryptionConfig {
107    /// Create a new encryption configuration (disabled by default).
108    #[must_use]
109    pub fn new() -> Self {
110        Self {
111            enabled: true,
112            providers: Vec::new(),
113            cache_ceks: true,
114        }
115    }
116
117    /// Register a key store provider.
118    pub fn register_provider(&mut self, provider: impl KeyStoreProvider + 'static) {
119        self.providers.push(Box::new(provider));
120    }
121
122    /// Builder method to add a key store provider.
123    #[must_use]
124    pub fn with_provider(mut self, provider: impl KeyStoreProvider + 'static) -> Self {
125        self.register_provider(provider);
126        self
127    }
128
129    /// Enable or disable CEK caching.
130    #[must_use]
131    pub fn with_cek_caching(mut self, enabled: bool) -> Self {
132        self.cache_ceks = enabled;
133        self
134    }
135
136    /// Get a provider by name.
137    pub fn get_provider(&self, name: &str) -> Option<&dyn KeyStoreProvider> {
138        self.providers
139            .iter()
140            .find(|p| p.provider_name() == name)
141            .map(|p| p.as_ref())
142    }
143
144    /// Check if encryption is ready (enabled and has providers).
145    #[must_use]
146    pub fn is_ready(&self) -> bool {
147        self.enabled && !self.providers.is_empty()
148    }
149}
150
151impl std::fmt::Debug for EncryptionConfig {
152    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
153        f.debug_struct("EncryptionConfig")
154            .field("enabled", &self.enabled)
155            .field("provider_count", &self.providers.len())
156            .field("cache_ceks", &self.cache_ceks)
157            .finish()
158    }
159}
160
161/// Runtime context for encryption operations.
162///
163/// This is the active encryption state for a connected client,
164/// including resolved CEKs and encryptors.
165///
166/// The context holds an `Arc<EncryptionConfig>` so providers remain accessible
167/// across connection retries/redirects where the `Config` (and its inner
168/// encryption config Arc) gets cloned multiple times.
169#[cfg(feature = "always-encrypted")]
170pub struct EncryptionContext {
171    /// Shared handle on the user-supplied configuration. Providers are looked
172    /// up by name through this reference, so an arbitrary number of `Arc`
173    /// clones do not lose access to them.
174    config: std::sync::Arc<EncryptionConfig>,
175    /// Cache for decrypted CEKs.
176    cek_cache: CekCache,
177    /// Whether caching is enabled.
178    cache_enabled: bool,
179}
180
181#[cfg(feature = "always-encrypted")]
182impl EncryptionContext {
183    /// Create a new encryption context from an Arc-wrapped configuration.
184    ///
185    /// The Arc is retained by the context so provider lookups continue to
186    /// work for the lifetime of the client — regardless of how many times
187    /// the outer `Config` has been cloned for retry/redirect handling.
188    pub fn from_arc(config: std::sync::Arc<EncryptionConfig>) -> Self {
189        let cache_enabled = config.cache_ceks;
190        Self {
191            config,
192            cek_cache: CekCache::new(),
193            cache_enabled,
194        }
195    }
196
197    /// Create a new encryption context from configuration.
198    pub fn new(config: EncryptionConfig) -> Self {
199        Self::from_arc(std::sync::Arc::new(config))
200    }
201
202    /// Get or decrypt a CEK for a column.
203    ///
204    /// This handles the CEK caching and decryption logic:
205    /// 1. Check cache for existing encryptor
206    /// 2. If not cached, decrypt CEK using the appropriate key store
207    /// 3. Create and cache the encryptor
208    pub async fn get_encryptor(
209        &self,
210        cek_entry: &CekTableEntry,
211    ) -> Result<Arc<AeadEncryptor>, EncryptionError> {
212        let cache_key = CekCacheKey::new(
213            cek_entry.database_id,
214            cek_entry.cek_id,
215            cek_entry.cek_version,
216        );
217
218        // Check cache first
219        if self.cache_enabled {
220            if let Some(encryptor) = self.cek_cache.get(&cache_key) {
221                return Ok(encryptor);
222            }
223        }
224
225        // Get the primary CEK value
226        let cek_value = cek_entry
227            .primary_value()
228            .ok_or_else(|| EncryptionError::CekDecryptionFailed("No CEK value available".into()))?;
229
230        // Find the appropriate key store provider via the shared config
231        let provider = self
232            .config
233            .get_provider(&cek_value.key_store_provider_name)
234            .ok_or_else(|| {
235                EncryptionError::KeyStoreNotFound(cek_value.key_store_provider_name.clone())
236            })?;
237
238        // Decrypt the CEK
239        let decrypted_cek = provider
240            .decrypt_cek(
241                &cek_value.cmk_path,
242                &cek_value.encryption_algorithm,
243                &cek_value.encrypted_value,
244            )
245            .await?;
246
247        // Create encryptor and cache it
248        if self.cache_enabled {
249            self.cek_cache.insert(cache_key, decrypted_cek)
250        } else {
251            // Create encryptor without caching
252            Ok(Arc::new(AeadEncryptor::new(&decrypted_cek)?))
253        }
254    }
255
256    /// Encrypt a value for a column.
257    ///
258    /// # Arguments
259    ///
260    /// * `plaintext` - The plaintext value to encrypt
261    /// * `cek_entry` - The CEK table entry for this column
262    /// * `encryption_type` - Deterministic or randomized encryption
263    pub async fn encrypt_value(
264        &self,
265        plaintext: &[u8],
266        cek_entry: &CekTableEntry,
267        encryption_type: EncryptionTypeWire,
268    ) -> Result<Vec<u8>, EncryptionError> {
269        let encryptor = self.get_encryptor(cek_entry).await?;
270
271        let enc_type = match encryption_type {
272            EncryptionTypeWire::Deterministic => mssql_auth::EncryptionType::Deterministic,
273            EncryptionTypeWire::Randomized => mssql_auth::EncryptionType::Randomized,
274            _ => {
275                return Err(EncryptionError::UnsupportedOperation(format!(
276                    "unsupported encryption type: {encryption_type:?}"
277                )));
278            }
279        };
280
281        encryptor.encrypt(plaintext, enc_type)
282    }
283
284    /// Decrypt a value from an encrypted column.
285    ///
286    /// # Arguments
287    ///
288    /// * `ciphertext` - The encrypted value
289    /// * `cek_entry` - The CEK table entry for this column
290    pub async fn decrypt_value(
291        &self,
292        ciphertext: &[u8],
293        cek_entry: &CekTableEntry,
294    ) -> Result<Vec<u8>, EncryptionError> {
295        let encryptor = self.get_encryptor(cek_entry).await?;
296        encryptor.decrypt(ciphertext)
297    }
298
299    /// Clear the CEK cache.
300    ///
301    /// Call this when keys may have been rotated.
302    pub fn clear_cache(&self) {
303        self.cek_cache.clear();
304    }
305
306    /// Check if a provider is registered.
307    pub fn has_provider(&self, name: &str) -> bool {
308        self.config.get_provider(name).is_some()
309    }
310}
311
312#[cfg(feature = "always-encrypted")]
313impl std::fmt::Debug for EncryptionContext {
314    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
315        f.debug_struct("EncryptionContext")
316            .field("provider_count", &self.config.providers.len())
317            .field("cache_entries", &self.cek_cache.len())
318            .field("cache_enabled", &self.cache_enabled)
319            .finish()
320    }
321}
322
323/// Column encryption metadata for a result set.
324///
325/// This combines the CEK table with per-column crypto metadata,
326/// providing all information needed to decrypt result columns.
327#[derive(Debug, Clone)]
328pub struct ResultSetEncryptionInfo {
329    /// CEK table for this result set.
330    pub cek_table: CekTable,
331    /// Crypto metadata for each column (index matches column ordinal).
332    pub column_crypto: Vec<Option<CryptoMetadata>>,
333}
334
335impl ResultSetEncryptionInfo {
336    /// Create encryption info for a result set.
337    pub fn new(cek_table: CekTable, column_count: usize) -> Self {
338        Self {
339            cek_table,
340            column_crypto: vec![None; column_count],
341        }
342    }
343
344    /// Set crypto metadata for a column.
345    pub fn set_column_crypto(&mut self, ordinal: usize, metadata: CryptoMetadata) {
346        if ordinal < self.column_crypto.len() {
347            self.column_crypto[ordinal] = Some(metadata);
348        }
349    }
350
351    /// Get the CEK entry for a column.
352    pub fn get_cek_for_column(&self, ordinal: usize) -> Option<&CekTableEntry> {
353        let crypto = self.column_crypto.get(ordinal)?.as_ref()?;
354        self.cek_table.get(crypto.cek_table_ordinal)
355    }
356
357    /// Check if a column is encrypted.
358    pub fn is_column_encrypted(&self, ordinal: usize) -> bool {
359        self.column_crypto
360            .get(ordinal)
361            .map(|c| c.is_some())
362            .unwrap_or(false)
363    }
364
365    /// Get the encryption type for a column.
366    pub fn get_encryption_type(&self, ordinal: usize) -> Option<EncryptionTypeWire> {
367        self.column_crypto
368            .get(ordinal)?
369            .as_ref()
370            .map(|c| c.encryption_type)
371    }
372}
373
374/// Parameter encryption metadata for a query.
375///
376/// This is returned by `sp_describe_parameter_encryption` and describes
377/// how each parameter should be encrypted.
378#[derive(Debug, Clone)]
379pub struct ParameterEncryptionInfo {
380    /// CEK table for parameters.
381    pub cek_table: CekTable,
382    /// Mapping from parameter name to crypto metadata.
383    pub parameters: HashMap<String, ParameterCryptoInfo>,
384}
385
386impl ParameterEncryptionInfo {
387    /// Create empty parameter encryption info.
388    pub fn new() -> Self {
389        Self {
390            cek_table: CekTable::new(),
391            parameters: HashMap::new(),
392        }
393    }
394
395    /// Add encryption info for a parameter.
396    pub fn add_parameter(&mut self, name: String, info: ParameterCryptoInfo) {
397        self.parameters.insert(name, info);
398    }
399
400    /// Get encryption info for a parameter.
401    pub fn get_parameter(&self, name: &str) -> Option<&ParameterCryptoInfo> {
402        self.parameters.get(name)
403    }
404
405    /// Check if a parameter needs encryption.
406    pub fn needs_encryption(&self, name: &str) -> bool {
407        self.parameters.contains_key(name)
408    }
409}
410
411impl Default for ParameterEncryptionInfo {
412    fn default() -> Self {
413        Self::new()
414    }
415}
416
417/// Encryption directive for a single parameter, parsed from result set 2 of
418/// `sp_describe_parameter_encryption`.
419#[derive(Debug, Clone)]
420pub struct ParameterCryptoInfo {
421    /// 0-based index into [`ParameterEncryptionInfo::cek_table`].
422    ///
423    /// The server reports a (often 1-based) key ordinal; the parser translates
424    /// it to this positional index so `cek_table.get(cek_ordinal)` resolves the
425    /// entry directly.
426    pub cek_ordinal: u16,
427    /// Encryption type (deterministic or randomized).
428    pub encryption_type: EncryptionTypeWire,
429    /// Encryption algorithm ID (2 = AEAD_AES_256_CBC_HMAC_SHA256).
430    pub algorithm_id: u8,
431    /// Normalization rule version applied to the plaintext before encryption.
432    pub normalization_rule_version: u8,
433}
434
435impl ParameterCryptoInfo {
436    /// Create new parameter crypto info.
437    pub fn new(
438        cek_ordinal: u16,
439        encryption_type: EncryptionTypeWire,
440        algorithm_id: u8,
441        normalization_rule_version: u8,
442    ) -> Self {
443        Self {
444            cek_ordinal,
445            encryption_type,
446            algorithm_id,
447            normalization_rule_version,
448        }
449    }
450}
451
452/// Parsing of the two result sets returned by `sp_describe_parameter_encryption`.
453///
454/// Result set 1 is the CEK table (one row per CMK-wrapping of each CEK); result
455/// set 2 lists, per parameter, how the server expects it encrypted. The column
456/// layout was captured against a live server (SQL Server 2016+): the first nine
457/// RS1 columns are stable across versions; SQL Server 2019+ append two enclave
458/// columns (`is_requested_by_enclave`, `column_master_key_signature`) which this
459/// non-enclave path ignores. Columns are read positionally to match the
460/// `Microsoft.Data.SqlClient` ordinals.
461#[cfg(feature = "always-encrypted")]
462impl ParameterEncryptionInfo {
463    /// Minimum RS1 column count (SQL Server 2017 returns exactly this; 2019+
464    /// return more, with the extra columns appended after these).
465    const RS1_MIN_COLS: usize = 9;
466    /// RS2 column count, stable across supported versions.
467    const RS2_MIN_COLS: usize = 6;
468
469    /// Parse `sp_describe_parameter_encryption` output into encryption metadata.
470    ///
471    /// `result_sets` must be the `ProcedureResult::result_sets` from that RPC.
472    /// Plaintext parameters (encryption type 0) are omitted from the result.
473    pub(crate) fn from_describe_result_sets(result_sets: &mut [ResultSet]) -> Result<Self, Error> {
474        if result_sets.len() < 2 {
475            return Err(Error::Protocol(format!(
476                "sp_describe_parameter_encryption returned {} result set(s), expected 2",
477                result_sets.len()
478            )));
479        }
480
481        // --- Result set 1: CEK table ---
482        let rs1_cols = result_sets[0].columns().len();
483        if rs1_cols < Self::RS1_MIN_COLS {
484            return Err(Error::Protocol(format!(
485                "sp_describe_parameter_encryption result set 1 has {rs1_cols} columns, expected >= {}",
486                Self::RS1_MIN_COLS
487            )));
488        }
489        let rs1_rows = result_sets[0].collect_all()?;
490
491        let mut entries: Vec<CekTableEntry> = Vec::new();
492        // Server-assigned key ordinal -> positional index into `entries`.
493        let mut ordinal_to_index: HashMap<i32, u16> = HashMap::new();
494
495        for row in &rs1_rows {
496            let key_ordinal = describe_int(row, 0, "column_encryption_key_ordinal")?;
497            let value = CekValue {
498                encrypted_value: describe_varbinary(
499                    row,
500                    5,
501                    "column_encryption_key_encrypted_value",
502                )?,
503                key_store_provider_name: describe_nvarchar(
504                    row,
505                    6,
506                    "column_master_key_store_provider_name",
507                )?,
508                cmk_path: describe_nvarchar(row, 7, "column_master_key_path")?,
509                encryption_algorithm: describe_nvarchar(
510                    row,
511                    8,
512                    "column_encryption_key_encryption_algorithm_name",
513                )?,
514            };
515
516            if let Some(&idx) = ordinal_to_index.get(&key_ordinal) {
517                // Another CMK-wrapping of an already-seen CEK (key rotation).
518                entries[idx as usize].values.push(value);
519            } else {
520                let idx = u16::try_from(entries.len()).map_err(|_| {
521                    Error::Protocol(
522                        "sp_describe_parameter_encryption returned too many CEKs".into(),
523                    )
524                })?;
525                ordinal_to_index.insert(key_ordinal, idx);
526                entries.push(CekTableEntry {
527                    database_id: describe_int(row, 1, "database_id")? as u32,
528                    cek_id: describe_int(row, 2, "column_encryption_key_id")? as u32,
529                    cek_version: describe_int(row, 3, "column_encryption_key_version")? as u32,
530                    cek_md_version: describe_md_version(row, 4)?,
531                    values: vec![value],
532                });
533            }
534        }
535        let cek_table = CekTable { entries };
536
537        // --- Result set 2: per-parameter directives ---
538        let rs2_cols = result_sets[1].columns().len();
539        if rs2_cols < Self::RS2_MIN_COLS {
540            return Err(Error::Protocol(format!(
541                "sp_describe_parameter_encryption result set 2 has {rs2_cols} columns, expected >= {}",
542                Self::RS2_MIN_COLS
543            )));
544        }
545        let rs2_rows = result_sets[1].collect_all()?;
546
547        let mut parameters = HashMap::new();
548        for row in &rs2_rows {
549            let name = describe_nvarchar(row, 1, "parameter_name")?;
550            let encryption_type_byte = describe_tinyint(row, 3, "column_encryption_type")?;
551            // 0 = the server determined this parameter needs no encryption.
552            if encryption_type_byte == 0 {
553                continue;
554            }
555            let encryption_type =
556                EncryptionTypeWire::from_u8(encryption_type_byte).ok_or_else(|| {
557                    Error::Protocol(format!(
558                        "sp_describe_parameter_encryption: invalid column_encryption_type {encryption_type_byte} for {name}"
559                    ))
560                })?;
561            let algorithm_id = describe_tinyint(row, 2, "column_encryption_algorithm")?;
562            let server_ordinal = describe_int(row, 4, "column_encryption_key_ordinal")?;
563            let normalization_rule_version =
564                describe_tinyint(row, 5, "column_encryption_normalization_rule_version")?;
565
566            let cek_ordinal = *ordinal_to_index.get(&server_ordinal).ok_or_else(|| {
567                Error::Protocol(format!(
568                    "sp_describe_parameter_encryption: parameter {name} references CEK ordinal {server_ordinal} absent from the CEK table"
569                ))
570            })?;
571
572            parameters.insert(
573                name,
574                ParameterCryptoInfo {
575                    cek_ordinal,
576                    encryption_type,
577                    algorithm_id,
578                    normalization_rule_version,
579                },
580            );
581        }
582
583        Ok(Self {
584            cek_table,
585            parameters,
586        })
587    }
588}
589
590/// Read an `int` describe column, erroring if it is absent or a different type.
591#[cfg(feature = "always-encrypted")]
592fn describe_int(row: &Row, idx: usize, col: &str) -> Result<i32, Error> {
593    match row.get_raw(idx) {
594        Some(SqlValue::Int(v)) => Ok(v),
595        other => Err(describe_type_error(col, idx, "int", other.as_ref())),
596    }
597}
598
599/// Read a `tinyint` describe column.
600#[cfg(feature = "always-encrypted")]
601fn describe_tinyint(row: &Row, idx: usize, col: &str) -> Result<u8, Error> {
602    match row.get_raw(idx) {
603        Some(SqlValue::TinyInt(v)) => Ok(v),
604        other => Err(describe_type_error(col, idx, "tinyint", other.as_ref())),
605    }
606}
607
608/// Read an `nvarchar` describe column.
609#[cfg(feature = "always-encrypted")]
610fn describe_nvarchar(row: &Row, idx: usize, col: &str) -> Result<String, Error> {
611    match row.get_raw(idx) {
612        Some(SqlValue::String(v)) => Ok(v),
613        other => Err(describe_type_error(col, idx, "nvarchar", other.as_ref())),
614    }
615}
616
617/// Read a `varbinary` describe column.
618#[cfg(feature = "always-encrypted")]
619fn describe_varbinary(row: &Row, idx: usize, col: &str) -> Result<bytes::Bytes, Error> {
620    match row.get_raw(idx) {
621        Some(SqlValue::Binary(v)) => Ok(v),
622        other => Err(describe_type_error(col, idx, "varbinary", other.as_ref())),
623    }
624}
625
626/// Read the `binary(8)` metadata-version column as a little-endian `u64`.
627#[cfg(feature = "always-encrypted")]
628fn describe_md_version(row: &Row, idx: usize) -> Result<u64, Error> {
629    match row.get_raw(idx) {
630        Some(SqlValue::Binary(b)) if b.len() == 8 => {
631            let mut bytes = [0u8; 8];
632            bytes.copy_from_slice(&b[..8]);
633            Ok(u64::from_le_bytes(bytes))
634        }
635        other => Err(describe_type_error(
636            "column_encryption_key_metadata_version",
637            idx,
638            "binary(8)",
639            other.as_ref(),
640        )),
641    }
642}
643
644/// Build a uniform error for an unexpected describe-column type.
645#[cfg(feature = "always-encrypted")]
646fn describe_type_error(col: &str, idx: usize, expected: &str, got: Option<&SqlValue>) -> Error {
647    let got = got.map_or("missing", SqlValue::type_name);
648    Error::Protocol(format!(
649        "sp_describe_parameter_encryption column {col} (#{idx}): expected {expected}, got {got}"
650    ))
651}
652
653/// Normalize a parameter value to the plaintext byte form Always Encrypted
654/// encrypts — SQL Server's "normalized" form for the value's type. The result
655/// is the plaintext input to [`EncryptionContext::encrypt_value`].
656///
657/// Normalization is type-specific and is **not** the regular TDS wire encoding:
658/// e.g. INT normalizes to 8 little-endian bytes (not 4), and strings/binaries
659/// carry no length prefix. These layouts are validated byte-for-byte against
660/// Microsoft.Data.SqlClient (see the `ae_normalization` tests). Only the types
661/// supported so far are handled; others return `UnsupportedOperation`.
662#[cfg(feature = "always-encrypted")]
663pub fn normalize_for_encryption(value: &SqlValue) -> Result<Vec<u8>, EncryptionError> {
664    match value {
665        // All integer types AND bit normalize to 8-byte little-endian (the value
666        // widened to i64). Validated against .NET: tinyint/smallint are 8 bytes,
667        // not their native 1/2 — a spec-reading would get this wrong.
668        SqlValue::Bool(v) => Ok(i64::from(*v).to_le_bytes().to_vec()),
669        SqlValue::TinyInt(v) => Ok(i64::from(*v).to_le_bytes().to_vec()),
670        SqlValue::SmallInt(v) => Ok(i64::from(*v).to_le_bytes().to_vec()),
671        SqlValue::Int(v) => Ok(i64::from(*v).to_le_bytes().to_vec()),
672        SqlValue::BigInt(v) => Ok(v.to_le_bytes().to_vec()),
673        // REAL/FLOAT: the IEEE-754 bits, little-endian (4 and 8 bytes).
674        SqlValue::Float(v) => Ok(v.to_le_bytes().to_vec()),
675        SqlValue::Double(v) => Ok(v.to_le_bytes().to_vec()),
676        // NVARCHAR: UTF-16LE code units, no length prefix.
677        SqlValue::String(s) => Ok(s.encode_utf16().flat_map(u16::to_le_bytes).collect()),
678        // VARBINARY: the raw bytes, no length prefix.
679        SqlValue::Binary(b) => Ok(b.to_vec()),
680        // UNIQUEIDENTIFIER: SQL Server's 16-byte mixed-endian GUID order (first
681        // three groups byte-reversed from the RFC layout, last 8 as-is).
682        #[cfg(feature = "uuid")]
683        SqlValue::Uuid(u) => {
684            let b = u.as_bytes();
685            Ok(vec![
686                b[3], b[2], b[1], b[0], b[5], b[4], b[7], b[6], b[8], b[9], b[10], b[11], b[12],
687                b[13], b[14], b[15],
688            ])
689        }
690        // DATE: 3-byte little-endian count of days since 0001-01-01.
691        // `num_days_from_ce` counts from day 1, so subtract 1.
692        #[cfg(feature = "chrono")]
693        SqlValue::Date(d) => {
694            use chrono::Datelike;
695            let days = (d.num_days_from_ce() - 1) as u32;
696            Ok(days.to_le_bytes()[..3].to_vec())
697        }
698        // DECIMAL/NUMERIC: 1 sign byte (0 negative, 1 positive) + 16-byte
699        // little-endian unscaled magnitude. Uses the value's own scale.
700        #[cfg(feature = "decimal")]
701        SqlValue::Decimal(d) => {
702            let mut out = Vec::with_capacity(17);
703            out.push(u8::from(!d.is_sign_negative()));
704            out.extend_from_slice(&d.mantissa().unsigned_abs().to_le_bytes());
705            Ok(out)
706        }
707        // MONEY and SMALLMONEY both normalize to the 8-byte MONEY form: the
708        // value scaled by 10_000 as an i64, high 32 bits then low 32 bits.
709        #[cfg(feature = "decimal")]
710        SqlValue::Money(d) | SqlValue::SmallMoney(d) => {
711            let cents = money_cents(d)?;
712            let mut out = ((cents >> 32) as i32).to_le_bytes().to_vec();
713            out.extend_from_slice(&(cents as u32).to_le_bytes());
714            Ok(out)
715        }
716        other => Err(EncryptionError::UnsupportedOperation(format!(
717            "Always Encrypted parameter encryption is not yet implemented for {}",
718            other.type_name()
719        ))),
720    }
721}
722
723/// The MONEY fixed-point value (`value * 10_000`) as an `i64`, rounding excess
724/// precision toward zero. Used by both MONEY and SMALLMONEY normalization.
725#[cfg(all(feature = "always-encrypted", feature = "decimal"))]
726fn money_cents(value: &rust_decimal::Decimal) -> Result<i64, EncryptionError> {
727    let mantissa = value.mantissa();
728    let scale = value.scale();
729    let cents: i128 = if scale <= 4 {
730        mantissa
731            .checked_mul(10_i128.pow(4 - scale))
732            .ok_or_else(|| {
733                EncryptionError::UnsupportedOperation("MONEY value out of range".into())
734            })?
735    } else {
736        mantissa / 10_i128.pow(scale - 4)
737    };
738    i64::try_from(cents)
739        .map_err(|_| EncryptionError::UnsupportedOperation("MONEY value out of range".into()))
740}
741
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Decode a hex string (either case) into bytes.
    ///
    /// Test-only helper for the captured reference vectors below. It panics on
    /// malformed input (odd length or non-hex digits), which always indicates a
    /// broken fixture rather than a driver bug.
    #[cfg(feature = "always-encrypted")]
    fn unhex(s: &str) -> Vec<u8> {
        assert_eq!(s.len() % 2, 0, "hex fixture must have an even length");
        (0..s.len())
            .step_by(2)
            .map(|i| {
                u8::from_str_radix(&s[i..i + 2], 16).expect("hex fixture must be valid hex")
            })
            .collect()
    }

    /// Normalize and deterministically encrypt each value under the CEK given
    /// as `cek_hex`. Then assert the ciphertext equals the reference captured
    /// from Microsoft.Data.SqlClient, byte-for-byte.
    ///
    /// Deterministic encryption yields identical ciphertext for identical
    /// plaintext under the same key. A match therefore proves our normalized
    /// bytes are exactly the bytes the .NET client encrypts.
    #[cfg(feature = "always-encrypted")]
    fn assert_ciphertexts_match_dotnet(cek_hex: &str, cases: &[(SqlValue, &str)]) {
        let cek = unhex(cek_hex);
        let enc = AeadEncryptor::new(&cek).unwrap();

        for (value, reference) in cases {
            let norm = normalize_for_encryption(value).unwrap();
            let cipher = enc
                .encrypt(&norm, mssql_auth::EncryptionType::Deterministic)
                .unwrap();
            assert_eq!(
                cipher,
                unhex(reference),
                "ciphertext for {} must match Microsoft.Data.SqlClient",
                value.type_name()
            );
        }
    }

    /// Reference ciphertexts captured from a live deterministic Always Encrypted
    /// INSERT via Microsoft.Data.SqlClient 5.2.2. Encrypting our normalization
    /// with the same CEK must reproduce them byte-for-byte — proving the
    /// normalized layout matches the real .NET client (notably INT -> 8 LE bytes,
    /// which is the layout a naive implementation would get wrong).
    #[cfg(feature = "always-encrypted")]
    #[test]
    fn ae_normalization_matches_dotnet() {
        use bytes::Bytes;

        assert_ciphertexts_match_dotnet(
            "B59D9F2C96784C232D53AB273D257DC79B7D2355BB82B1EC7054CE25E25F7B44",
            &[
                (
                    SqlValue::Int(42),
                    "01102FC5DEC5D3E463A8F4BDF512AA74E6AB953BA9A2F3F9A98CD18446B007DE5A6E2A1D1EB775035EA189CA5160A935CE093CAA9BB7E9233BB333AADEE86FDE1D",
                ),
                (
                    SqlValue::String("Ada".to_string()),
                    "01BFAC40E6DA541ACEFAD8ECF5598DB77B0C5349CFACBC3C9221C01B6037E593B78E8F398F620F837BD6A4A2B644125C4188DF278B94479B2218466D91107FE417",
                ),
                (
                    SqlValue::Binary(Bytes::from_static(&[0x01, 0x02, 0x03])),
                    "01ADE71457495F00FC9A16456F1B1EECB901D88DE97887025C189B1C4432E02071AB7594C48518CA5621E90165FAE337475B4CF3A3D00EF2D862FB0473713DF1E1",
                ),
            ],
        );
    }

    /// `normalize_for_encryption` rejects values it has no normalization for
    /// rather than silently producing wrong bytes. NULL is never normalized
    /// (it is handled as a NULL parameter upstream), so it exercises the
    /// catch-all rejection arm and stays unsupported as more types are added.
    #[cfg(feature = "always-encrypted")]
    #[test]
    fn ae_normalization_rejects_unnormalizable_value() {
        assert!(normalize_for_encryption(&SqlValue::Null).is_err());
    }

    /// Numeric-scalar normalization, validated byte-for-byte against
    /// Microsoft.Data.SqlClient (same method as `ae_normalization_matches_dotnet`,
    /// captured with a fresh CEK). This is the interop guarantee: a value the
    /// driver encrypts is the value .NET would encrypt. Notable: every integer
    /// width and bit normalize to 8 bytes, real to 4, float to 8.
    #[cfg(feature = "always-encrypted")]
    #[test]
    fn ae_normalization_matches_dotnet_numeric() {
        assert_ciphertexts_match_dotnet(
            "9590E42A8A6C8F13B5D09B8D5A128EF8B3A4A10301C7AF24AFC62ED0E02342F7",
            &[
                (
                    SqlValue::BigInt(0x0102030405060708),
                    "01E765FC4696660028BFD48FCAEAED81E0EB423CFF433CA97F1B2FF02F70744E7265C2AE73CAA562FFA98AF98CB1D3EF6A4649B3640359E1DB7D170C80E639DA68",
                ),
                (
                    SqlValue::SmallInt(258),
                    "012545AB817E1AEBDCEE1C00AEBFF3A013CAD20E0377BEFDD9186C263F8D1A909C313A753996F1B5E4A4AE17E901F6F781DCA707544766995D339601CA414063A0",
                ),
                (
                    SqlValue::TinyInt(200),
                    "01A97C33480277D16FFAEDA9068173D4173378542F2887EBCD31CDEEEB116BD59D48F9D459BDDCABAE469E891B4F82AA3D283440CA1B5E9FFC150F9D0AE54EC21E",
                ),
                (
                    SqlValue::Bool(true),
                    "01DDE18564051D630EE026331BCCAFC8F4122CC3919F81459F37D9C0E0C64A5317FCA08660FE5FC855917B97B72013F25B85ADD14ADDD7D5ED022EB1297FF29A7E",
                ),
                (
                    SqlValue::Float(3.5),
                    "017A452760E7BA7AA6A716F6707F55D9C3A81683C04A6B561B13AC1D8A848E93E239BB922EE3EE628B6D0081A590BB11747CC25D216240FB10171A0FA3B99A2DB3",
                ),
                (
                    SqlValue::Double(3.5),
                    "0171611557351FBC4561EBF0B9C98E0DC38AD2BD3E2C1D1E82F185D7E67D0425E506D11DD67BA3EB38F34FB01A8FCEF7E4B9A7256944334A521526613CFF6C8C5F",
                ),
            ],
        );
    }

    /// UUID and DATE normalization, validated byte-for-byte against
    /// Microsoft.Data.SqlClient: uuid uses SQL Server's mixed-endian GUID byte
    /// order, date is a 3-byte little-endian day count since 0001-01-01.
    #[cfg(all(feature = "always-encrypted", feature = "uuid", feature = "chrono"))]
    #[test]
    fn ae_normalization_matches_dotnet_uuid_date() {
        assert_ciphertexts_match_dotnet(
            "9590E42A8A6C8F13B5D09B8D5A128EF8B3A4A10301C7AF24AFC62ED0E02342F7",
            &[
                (
                    SqlValue::Uuid(
                        uuid::Uuid::parse_str("01020304-0506-0708-090a-0b0c0d0e0f10").unwrap(),
                    ),
                    "01F58635AA18692D68BDF551ECDD7AC3A56682D3F91F111F8D8F36D5425C405A8F6AB3ED3C3666444478476BD65FF40DC83F6831F502826AFEEC3116F71A7A2020CCD254F4BA28FCDC0F96BA2E5264AE9E",
                ),
                (
                    SqlValue::Date(chrono::NaiveDate::from_ymd_opt(2024, 3, 15).unwrap()),
                    "0188B4F75A1F4BDA53C9CDDC1918C09CB57F68E13F5560F1F1D7168FE70707337B1156A97915B244F3C03D3E7352882A599511BD243471FD03683F371CF44E4B76",
                ),
            ],
        );
    }

    /// DECIMAL and MONEY/SMALLMONEY normalization, validated byte-for-byte
    /// against Microsoft.Data.SqlClient: decimal is a sign byte plus a 16-byte
    /// little-endian unscaled magnitude; money and smallmoney both use the
    /// 8-byte MONEY form (value × 10_000, high then low 32 bits).
    #[cfg(all(feature = "always-encrypted", feature = "decimal"))]
    #[test]
    fn ae_normalization_matches_dotnet_decimal_money() {
        let dec = rust_decimal::Decimal::new(123_456_789, 4); // 12345.6789
        let money = rust_decimal::Decimal::new(123_400, 4); // 12.3400
        // MONEY and SMALLMONEY normalize identically, so .NET produced the same
        // ciphertext for both; share one reference so they cannot drift apart.
        let money_reference = "01B4CE4CAD8D6B241A1555C377A0ADD4C79424DD5162F710D116594F725C1BAB015169A0C7716076EEC90E013519B961DEF427BFC32462D9E45D166C791B73F793";

        assert_ciphertexts_match_dotnet(
            "CBFB5AE21FB517C65DA0C6E8E11969C630798E473EF5827A70398012DF1D4B9E",
            &[
                (
                    SqlValue::Decimal(dec),
                    "018FAE46024B9B406C23600E6A9C694F9A9B39B785A995689EBE19437BA7E75768011A035A5B54B5E495512EBB46AE1146130940A0D0D834D61AA89B5AD9F71FFAF6EEEAE77E4856BA2AA5E016E2950A8D",
                ),
                (SqlValue::Money(money), money_reference),
                (SqlValue::SmallMoney(money), money_reference),
            ],
        );
    }

    #[test]
    fn test_encryption_config_defaults() {
        let config = EncryptionConfig::new();
        assert!(config.enabled);
        assert!(config.cache_ceks);
        assert!(!config.is_ready()); // No providers
    }

    #[test]
    fn test_result_set_encryption_info() {
        let cek_table = CekTable::new();
        let mut info = ResultSetEncryptionInfo::new(cek_table, 3);

        assert!(!info.is_column_encrypted(0));
        assert!(!info.is_column_encrypted(1));
        assert!(!info.is_column_encrypted(2));

        let metadata = CryptoMetadata {
            cek_table_ordinal: 0,
            base_user_type: 0,
            base_col_type: 0x26,
            base_type_info: tds_protocol::token::TypeInfo::default(),
            algorithm_id: 2,
            encryption_type: EncryptionTypeWire::Deterministic,
            normalization_version: 1,
        };

        info.set_column_crypto(1, metadata);
        assert!(!info.is_column_encrypted(0));
        assert!(info.is_column_encrypted(1));
        assert!(!info.is_column_encrypted(2));

        assert_eq!(
            info.get_encryption_type(1),
            Some(EncryptionTypeWire::Deterministic)
        );
    }

    #[test]
    fn test_parameter_encryption_info() {
        let mut info = ParameterEncryptionInfo::new();

        assert!(!info.needs_encryption("@p1"));

        let crypto = ParameterCryptoInfo::new(0, EncryptionTypeWire::Randomized, 2, 1);
        info.add_parameter("@p1".to_string(), crypto);

        assert!(info.needs_encryption("@p1"));
        assert!(!info.needs_encryption("@p2"));

        let param = info.get_parameter("@p1").unwrap();
        assert_eq!(param.encryption_type, EncryptionTypeWire::Randomized);
    }

    /// Parse synthetic `sp_describe_parameter_encryption` result sets that mirror
    /// the live wire shape (captured in `.tmp/ae-3a2-describe-schema.md`). The
    /// column *order* is validated separately by the live test; this exercises
    /// the logic the live single-CEK/single-CMK case cannot: grouping multiple
    /// CMK-wrappings under one CEK, translating the server's (1-based) key
    /// ordinal to a positional index, little-endian `binary(8)` md-version
    /// decode, and skipping plaintext parameters.
    #[cfg(feature = "always-encrypted")]
    #[test]
    fn parse_describe_result_sets_groups_ceks_and_skips_plaintext() {
        use crate::row::{Column, Row};
        use crate::stream::ResultSet;
        use bytes::Bytes;

        /// Build a result set with `n_cols` placeholder columns and the given rows.
        fn rs(n_cols: usize, rows: Vec<Vec<SqlValue>>) -> ResultSet {
            let cols: Vec<Column> = (0..n_cols)
                .map(|i| Column::new(format!("c{i}"), i, "x"))
                .collect();
            let rows = rows
                .into_iter()
                .map(|vals| Row::from_values(cols.clone(), vals))
                .collect();
            ResultSet::new(cols, rows)
        }

        let mdv1 = Bytes::from_static(&[1, 0, 0, 0, 0, 0, 0, 0]); // -> 1
        let mdv2 = Bytes::from_static(&[255, 0, 0, 0, 0, 0, 0, 0]); // -> 255

        // RS1: CEK ordinal 1 wrapped by two CMKs (rotation), plus CEK ordinal 2.
        // Columns follow the documented sp_describe_parameter_encryption layout:
        // key ordinal, database id, CEK id, CEK version, CEK metadata version
        // (binary(8)), encrypted CEK, CMK store provider, CMK path, algorithm.
        let rs1 = rs(
            9,
            vec![
                vec![
                    SqlValue::Int(1),
                    SqlValue::Int(7),
                    SqlValue::Int(56),
                    SqlValue::Int(1),
                    SqlValue::Binary(mdv1.clone()),
                    SqlValue::Binary(Bytes::from_static(b"env-a")),
                    SqlValue::String("IN_MEMORY_KEY_STORE".into()),
                    SqlValue::String("path-a".into()),
                    SqlValue::String("RSA_OAEP".into()),
                ],
                vec![
                    SqlValue::Int(1),
                    SqlValue::Int(7),
                    SqlValue::Int(56),
                    SqlValue::Int(1),
                    SqlValue::Binary(mdv1),
                    SqlValue::Binary(Bytes::from_static(b"env-a2")),
                    SqlValue::String("PROV_2".into()),
                    SqlValue::String("path-a2".into()),
                    SqlValue::String("RSA_OAEP".into()),
                ],
                vec![
                    SqlValue::Int(2),
                    SqlValue::Int(7),
                    SqlValue::Int(57),
                    SqlValue::Int(1),
                    SqlValue::Binary(mdv2),
                    SqlValue::Binary(Bytes::from_static(b"env-b")),
                    SqlValue::String("IN_MEMORY_KEY_STORE".into()),
                    SqlValue::String("path-b".into()),
                    SqlValue::String("RSA_OAEP".into()),
                ],
            ],
        );

        // RS2: @det on CEK ordinal 1, @rand on CEK ordinal 2, @plain plaintext.
        // Columns: parameter ordinal, name, algorithm, encryption type
        // (1 = deterministic, 2 = randomized), 1-based CEK ordinal, normalization
        // rule version. The plaintext parameter reports zeros.
        let rs2 = rs(
            6,
            vec![
                vec![
                    SqlValue::Int(1),
                    SqlValue::String("@det".into()),
                    SqlValue::TinyInt(2),
                    SqlValue::TinyInt(1),
                    SqlValue::Int(1),
                    SqlValue::TinyInt(1),
                ],
                vec![
                    SqlValue::Int(2),
                    SqlValue::String("@rand".into()),
                    SqlValue::TinyInt(2),
                    SqlValue::TinyInt(2),
                    SqlValue::Int(2),
                    SqlValue::TinyInt(1),
                ],
                vec![
                    SqlValue::Int(3),
                    SqlValue::String("@plain".into()),
                    SqlValue::TinyInt(0),
                    SqlValue::TinyInt(0),
                    SqlValue::Int(0),
                    SqlValue::TinyInt(0),
                ],
            ],
        );

        let mut sets = vec![rs1, rs2];
        let info = ParameterEncryptionInfo::from_describe_result_sets(&mut sets).unwrap();

        assert_eq!(info.cek_table.len(), 2);
        let e0 = info.cek_table.get(0).unwrap();
        assert_eq!(e0.cek_id, 56);
        assert_eq!(e0.cek_md_version, 1);
        assert_eq!(e0.values.len(), 2, "two CMK-wrappings group under one CEK");
        assert_eq!(e0.values[0].key_store_provider_name, "IN_MEMORY_KEY_STORE");
        assert_eq!(e0.values[1].key_store_provider_name, "PROV_2");
        let e1 = info.cek_table.get(1).unwrap();
        assert_eq!(e1.cek_id, 57);
        assert_eq!(e1.cek_md_version, 255);

        let det = info.get_parameter("@det").unwrap();
        assert_eq!(det.encryption_type, EncryptionTypeWire::Deterministic);
        assert_eq!(det.algorithm_id, 2);
        assert_eq!(det.normalization_rule_version, 1);
        assert_eq!(det.cek_ordinal, 0, "server ordinal 1 -> positional index 0");

        let rand = info.get_parameter("@rand").unwrap();
        assert_eq!(rand.encryption_type, EncryptionTypeWire::Randomized);
        assert_eq!(
            rand.cek_ordinal, 1,
            "server ordinal 2 -> positional index 1"
        );

        assert!(!info.needs_encryption("@plain"));
        assert_eq!(info.parameters.len(), 2);
    }

    /// A truncated response (fewer than two result sets) must be rejected, not
    /// silently treated as "no parameters need encryption".
    #[cfg(feature = "always-encrypted")]
    #[test]
    fn parse_describe_result_sets_rejects_missing_result_set() {
        use crate::row::{Column, Row};
        use crate::stream::ResultSet;

        let cols: Vec<Column> = (0..9)
            .map(|i| Column::new(format!("c{i}"), i, "x"))
            .collect();
        let mut sets = vec![ResultSet::new(cols, Vec::<Row>::new())];
        assert!(ParameterEncryptionInfo::from_describe_result_sets(&mut sets).is_err());
    }
}