Skip to main content

mssql_client/
encryption.rs

1//! Always Encrypted client-side encryption and decryption.
2//!
3//! This module provides the infrastructure for SQL Server's Always Encrypted feature,
4//! which enables client-side encryption of sensitive database columns.
5//!
6//! ## Architecture
7//!
8//! Always Encrypted uses a two-tier key hierarchy:
9//!
10//! ```text
11//! Column Master Key (CMK) - External (KeyVault, CertStore, HSM)
12//!         │
13//!         ▼ RSA-OAEP unwrap
14//! Column Encryption Key (CEK) - Stored encrypted in database
15//!         │
16//!         ▼ AEAD_AES_256_CBC_HMAC_SHA256
17//! Encrypted Column Data
18//! ```
19//!
20//! ## Usage
21//!
22//! ```rust,no_run
23//! # async fn with_always_encrypted() -> Result<(), Box<dyn std::error::Error>> {
24//! # #[cfg(feature = "always-encrypted")]
25//! # {
26//! # let conn_str = "Server=localhost;Database=db;Encrypt=strict;Column Encryption Setting=Enabled";
27//! use mssql_client::{Client, Config, EncryptionConfig};
28//! use mssql_auth::InMemoryKeyStore;
29//!
30//! // Register a key-store provider, then attach it to the connection config.
31//! let key_store = InMemoryKeyStore::new();
32//! let encryption_config = EncryptionConfig::new().with_provider(key_store);
33//!
34//! let config = Config::from_connection_string(conn_str)?
35//!     .with_column_encryption(encryption_config);
36//!
37//! let _client = Client::connect(config).await?;
38//! # }
39//! # Ok(())
40//! # }
41//! ```
42//!
//! Equivalently, set `Column Encryption Setting=Enabled` in the connection
//! string. Key-store providers ship in `mssql-auth`: `InMemoryKeyStore`
//! (development and testing only), `AzureKeyVaultProvider` (`azure-identity`
//! feature), and `WindowsCertStoreProvider` (`sspi-auth`, Windows only).
//! Implement [`mssql_auth::KeyStoreProvider`] for custom key storage. Do **not**
//! substitute T-SQL `ENCRYPTBYKEY`: the server sees the plaintext it encrypts,
//! which defeats the purpose of Always Encrypted.
49//!
50//! ## How decryption works
51//!
52//! 1. Always Encrypted support is negotiated in LOGIN7 (`FEATURE_EXT`).
53//! 2. `ColMetaData` carries [`CryptoMetadata`] and the [`CekTable`]; column
54//!    encryption keys are resolved asynchronously up front (calling the key-store
55//!    providers).
56//! 3. Each encrypted cell is decrypted during row parsing via
57//!    AEAD_AES_256_CBC_HMAC_SHA256, with the HMAC verified before decryption.
58//!
59//! Reads are transparent across `query`, `call_procedure`, the procedure
60//! builder, and multi-result queries. Parameter (write) encryption is wired
61//! into parameterized `query`/`execute` for the common scalar types — `int`,
62//! `tinyint`, `smallint`, `bigint`, `bit`, `real`, `float`, `nvarchar`,
63//! `varbinary`, `uniqueidentifier`, `date`, `money`, `smallmoney`, `decimal`
64//! (via `numeric(value, precision, scale)`), and typed `NULL` (via
65//! `null::<T>()`): with `Column Encryption Setting=Enabled` the
66//! driver describes the parameters (`sp_describe_parameter_encryption`),
67//! encrypts those bound to encrypted columns client-side, and sends them as
68//! encrypted RPC parameters (deterministic and randomized). The temporal and
69//! fixed-width types are supported through typed-parameter wrappers —
70//! `time(v, scale)`, `datetime2(v, scale)`, `datetimeoffset(v, scale)`,
71//! `datetime(v)` (legacy), the `SmallDateTime` wrapper, and `char(v, len)` /
72//! `nchar(v, len)` / `binary(v, len)`. The full scalar/temporal/fixed-width set
73//! is now covered.
74//!
75//! Bind a `decimal` parameter with `numeric(value, precision, scale)`, not a
76//! plain `Decimal`, and a scaled temporal or fixed-width value with its wrapper,
77//! not a bare `NaiveDateTime`/`NaiveTime`/`String`: an encrypted column requires
78//! the declared type — precision/scale/length included — to match the column
79//! exactly, which a bare value can't convey, so the server rejects it with
80//! `Operand type clash` (Msg 206) at the describe step. Encrypted `char`/`nchar`
81//! columns must use a `*_BIN2` collation (a SQL Server requirement for
82//! deterministic encryption of character types); `char` is encoded as
83//! Windows-1252. The scale-7 temporal and the fixed-width forms are validated
84//! byte-for-byte against `Microsoft.Data.SqlClient`; lower temporal scales are
85//! validated by live round-trip (Microsoft's own client defaults temporal
86//! parameters to scale 7, so it can't emit lower-scale forms for a byte-exact
87//! comparison).
88//!
89//! ## Security Model
90//!
//! - **Client-only decryption**: SQL Server never sees the plaintext of encrypted columns
//! - **DBA protection**: Even database administrators cannot read encrypted data
//! - **Key separation**: The CMK stays in its external key store and is never sent to
//!   SQL Server; the database holds only CEKs encrypted under it
94
95use std::collections::HashMap;
96
97use mssql_auth::KeyStoreProvider;
98use tds_protocol::crypto::{CekTable, CekTableEntry, CryptoMetadata, EncryptionTypeWire};
99
100#[cfg(feature = "always-encrypted")]
101use mssql_auth::{AeadEncryptor, CekCache, CekCacheKey, EncryptionError};
102#[cfg(feature = "always-encrypted")]
103use mssql_types::SqlValue;
104#[cfg(feature = "always-encrypted")]
105use std::sync::Arc;
106
107#[cfg(feature = "always-encrypted")]
108use crate::{Error, row::Row, stream::ResultSet};
109#[cfg(feature = "always-encrypted")]
110use tds_protocol::crypto::CekValue;
111
112/// Configuration for Always Encrypted feature.
113#[derive(Default)]
114pub struct EncryptionConfig {
115    /// Whether encryption is enabled.
116    pub enabled: bool,
117    /// Registered key store providers.
118    providers: Vec<Box<dyn KeyStoreProvider>>,
119    /// Whether to cache decrypted CEKs for performance.
120    pub cache_ceks: bool,
121}
122
123impl EncryptionConfig {
124    /// Create a new encryption configuration (disabled by default).
125    #[must_use]
126    pub fn new() -> Self {
127        Self {
128            enabled: true,
129            providers: Vec::new(),
130            cache_ceks: true,
131        }
132    }
133
134    /// Register a key store provider.
135    pub fn register_provider(&mut self, provider: impl KeyStoreProvider + 'static) {
136        self.providers.push(Box::new(provider));
137    }
138
139    /// Builder method to add a key store provider.
140    #[must_use]
141    pub fn with_provider(mut self, provider: impl KeyStoreProvider + 'static) -> Self {
142        self.register_provider(provider);
143        self
144    }
145
146    /// Enable or disable CEK caching.
147    #[must_use]
148    pub fn with_cek_caching(mut self, enabled: bool) -> Self {
149        self.cache_ceks = enabled;
150        self
151    }
152
153    /// Get a provider by name.
154    pub fn get_provider(&self, name: &str) -> Option<&dyn KeyStoreProvider> {
155        self.providers
156            .iter()
157            .find(|p| p.provider_name() == name)
158            .map(|p| p.as_ref())
159    }
160
161    /// Check if encryption is ready (enabled and has providers).
162    #[must_use]
163    pub fn is_ready(&self) -> bool {
164        self.enabled && !self.providers.is_empty()
165    }
166}
167
168impl std::fmt::Debug for EncryptionConfig {
169    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
170        f.debug_struct("EncryptionConfig")
171            .field("enabled", &self.enabled)
172            .field("provider_count", &self.providers.len())
173            .field("cache_ceks", &self.cache_ceks)
174            .finish()
175    }
176}
177
/// Runtime context for encryption operations.
///
/// This is the active encryption state for a connected client,
/// including resolved CEKs and encryptors.
///
/// The context holds an `Arc<EncryptionConfig>` so providers remain accessible
/// across connection retries/redirects where the `Config` (and its inner
/// encryption config Arc) gets cloned multiple times.
#[cfg(feature = "always-encrypted")]
pub struct EncryptionContext {
    /// Shared handle on the user-supplied configuration. Providers are looked
    /// up by name through this reference, so an arbitrary number of `Arc`
    /// clones do not lose access to them.
    config: std::sync::Arc<EncryptionConfig>,
    /// Cache for decrypted CEKs.
    cek_cache: CekCache,
    /// Whether caching is enabled.
    cache_enabled: bool,
}

#[cfg(feature = "always-encrypted")]
impl EncryptionContext {
    /// Create a new encryption context from an Arc-wrapped configuration.
    ///
    /// The Arc is retained by the context so provider lookups continue to
    /// work for the lifetime of the client — regardless of how many times
    /// the outer `Config` has been cloned for retry/redirect handling.
    pub fn from_arc(config: std::sync::Arc<EncryptionConfig>) -> Self {
        let cache_enabled = config.cache_ceks;
        Self {
            config,
            cek_cache: CekCache::new(),
            cache_enabled,
        }
    }

    /// Create a new encryption context from configuration.
    pub fn new(config: EncryptionConfig) -> Self {
        Self::from_arc(std::sync::Arc::new(config))
    }

    /// Get or build the AEAD encryptor for a column's CEK.
    ///
    /// 1. When caching is enabled, return the cached encryptor if present.
    /// 2. Otherwise unwrap the CEK through the key-store provider named by each
    ///    of the entry's CMK-wrapped values, in order, stopping at the first
    ///    success. A CEK carries several values while its column master key is
    ///    being rotated. Trying them all means an unregistered or failing
    ///    provider for one CMK does not block a usable wrapping under another.
    /// 3. Build the encryptor from the unwrapped key, caching it if enabled.
    ///
    /// # Errors
    ///
    /// Returns [`EncryptionError::CekDecryptionFailed`] if the entry has no
    /// values at all. If no value can be unwrapped, returns the error from the
    /// last value attempted (for example [`EncryptionError::KeyStoreNotFound`]
    /// when its provider is not registered).
    pub async fn get_encryptor(
        &self,
        cek_entry: &CekTableEntry,
    ) -> Result<Arc<AeadEncryptor>, EncryptionError> {
        // Check cache first
        if self.cache_enabled {
            if let Some(encryptor) = self.cek_cache.get(&Self::cache_key(cek_entry)) {
                return Ok(encryptor);
            }
        }

        // Try every CMK wrapping; the first one that unwraps wins.
        let mut last_error: Option<EncryptionError> = None;
        for cek_value in &cek_entry.values {
            match self.encryptor_from_value(cek_entry, cek_value).await {
                Ok(encryptor) => return Ok(encryptor),
                Err(err) => last_error = Some(err),
            }
        }

        Err(last_error.unwrap_or_else(|| {
            EncryptionError::CekDecryptionFailed("No CEK value available".into())
        }))
    }

    /// Cache key identifying a CEK: database id, key id, and key version.
    fn cache_key(cek_entry: &CekTableEntry) -> CekCacheKey {
        CekCacheKey::new(
            cek_entry.database_id,
            cek_entry.cek_id,
            cek_entry.cek_version,
        )
    }

    /// Unwrap one CMK-wrapped value of `cek_entry` with its named provider and
    /// build the encryptor, caching it when caching is enabled.
    async fn encryptor_from_value(
        &self,
        cek_entry: &CekTableEntry,
        cek_value: &CekValue,
    ) -> Result<Arc<AeadEncryptor>, EncryptionError> {
        // Find the appropriate key store provider via the shared config
        let provider = self
            .config
            .get_provider(&cek_value.key_store_provider_name)
            .ok_or_else(|| {
                EncryptionError::KeyStoreNotFound(cek_value.key_store_provider_name.clone())
            })?;

        // Decrypt the CEK
        let decrypted_cek = provider
            .decrypt_cek(
                &cek_value.cmk_path,
                &cek_value.encryption_algorithm,
                &cek_value.encrypted_value,
            )
            .await?;

        // Create encryptor and cache it
        if self.cache_enabled {
            self.cek_cache.insert(Self::cache_key(cek_entry), decrypted_cek)
        } else {
            // Create encryptor without caching
            Ok(Arc::new(AeadEncryptor::new(&decrypted_cek)?))
        }
    }

    /// Encrypt a value for a column.
    ///
    /// # Arguments
    ///
    /// * `plaintext` - The plaintext value to encrypt
    /// * `cek_entry` - The CEK table entry for this column
    /// * `encryption_type` - Deterministic or randomized encryption
    ///
    /// # Errors
    ///
    /// Propagates [`Self::get_encryptor`] failures, and returns
    /// `UnsupportedOperation` for wire encryption types other than
    /// deterministic or randomized.
    pub async fn encrypt_value(
        &self,
        plaintext: &[u8],
        cek_entry: &CekTableEntry,
        encryption_type: EncryptionTypeWire,
    ) -> Result<Vec<u8>, EncryptionError> {
        let encryptor = self.get_encryptor(cek_entry).await?;

        let enc_type = match encryption_type {
            EncryptionTypeWire::Deterministic => mssql_auth::EncryptionType::Deterministic,
            EncryptionTypeWire::Randomized => mssql_auth::EncryptionType::Randomized,
            _ => {
                return Err(EncryptionError::UnsupportedOperation(format!(
                    "unsupported encryption type: {encryption_type:?}"
                )));
            }
        };

        encryptor.encrypt(plaintext, enc_type)
    }

    /// Decrypt a value from an encrypted column.
    ///
    /// # Arguments
    ///
    /// * `ciphertext` - The encrypted value
    /// * `cek_entry` - The CEK table entry for this column
    pub async fn decrypt_value(
        &self,
        ciphertext: &[u8],
        cek_entry: &CekTableEntry,
    ) -> Result<Vec<u8>, EncryptionError> {
        let encryptor = self.get_encryptor(cek_entry).await?;
        encryptor.decrypt(ciphertext)
    }

    /// Clear the CEK cache.
    ///
    /// Call this when keys may have been rotated.
    pub fn clear_cache(&self) {
        self.cek_cache.clear();
    }

    /// Check if a provider is registered.
    pub fn has_provider(&self, name: &str) -> bool {
        self.config.get_provider(name).is_some()
    }
}
328
#[cfg(feature = "always-encrypted")]
impl std::fmt::Debug for EncryptionContext {
    /// Summarizes the context as three counters/flags: the number of
    /// registered providers, the number of cached CEK entries, and whether
    /// caching is enabled. No key material is printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("EncryptionContext");
        out.field("provider_count", &self.config.providers.len());
        out.field("cache_entries", &self.cek_cache.len());
        out.field("cache_enabled", &self.cache_enabled);
        out.finish()
    }
}
339
340/// Column encryption metadata for a result set.
341///
342/// This combines the CEK table with per-column crypto metadata,
343/// providing all information needed to decrypt result columns.
344#[derive(Debug, Clone)]
345pub struct ResultSetEncryptionInfo {
346    /// CEK table for this result set.
347    pub cek_table: CekTable,
348    /// Crypto metadata for each column (index matches column ordinal).
349    pub column_crypto: Vec<Option<CryptoMetadata>>,
350}
351
352impl ResultSetEncryptionInfo {
353    /// Create encryption info for a result set.
354    pub fn new(cek_table: CekTable, column_count: usize) -> Self {
355        Self {
356            cek_table,
357            column_crypto: vec![None; column_count],
358        }
359    }
360
361    /// Set crypto metadata for a column.
362    pub fn set_column_crypto(&mut self, ordinal: usize, metadata: CryptoMetadata) {
363        if ordinal < self.column_crypto.len() {
364            self.column_crypto[ordinal] = Some(metadata);
365        }
366    }
367
368    /// Get the CEK entry for a column.
369    pub fn get_cek_for_column(&self, ordinal: usize) -> Option<&CekTableEntry> {
370        let crypto = self.column_crypto.get(ordinal)?.as_ref()?;
371        self.cek_table.get(crypto.cek_table_ordinal)
372    }
373
374    /// Check if a column is encrypted.
375    pub fn is_column_encrypted(&self, ordinal: usize) -> bool {
376        self.column_crypto
377            .get(ordinal)
378            .map(|c| c.is_some())
379            .unwrap_or(false)
380    }
381
382    /// Get the encryption type for a column.
383    pub fn get_encryption_type(&self, ordinal: usize) -> Option<EncryptionTypeWire> {
384        self.column_crypto
385            .get(ordinal)?
386            .as_ref()
387            .map(|c| c.encryption_type)
388    }
389}
390
391/// Parameter encryption metadata for a query.
392///
393/// This is returned by `sp_describe_parameter_encryption` and describes
394/// how each parameter should be encrypted.
395#[derive(Debug, Clone)]
396pub struct ParameterEncryptionInfo {
397    /// CEK table for parameters.
398    pub cek_table: CekTable,
399    /// Mapping from parameter name to crypto metadata.
400    pub parameters: HashMap<String, ParameterCryptoInfo>,
401}
402
403impl ParameterEncryptionInfo {
404    /// Create empty parameter encryption info.
405    pub fn new() -> Self {
406        Self {
407            cek_table: CekTable::new(),
408            parameters: HashMap::new(),
409        }
410    }
411
412    /// Add encryption info for a parameter.
413    pub fn add_parameter(&mut self, name: String, info: ParameterCryptoInfo) {
414        self.parameters.insert(name, info);
415    }
416
417    /// Get encryption info for a parameter.
418    pub fn get_parameter(&self, name: &str) -> Option<&ParameterCryptoInfo> {
419        self.parameters.get(name)
420    }
421
422    /// Check if a parameter needs encryption.
423    pub fn needs_encryption(&self, name: &str) -> bool {
424        self.parameters.contains_key(name)
425    }
426}
427
428impl Default for ParameterEncryptionInfo {
429    fn default() -> Self {
430        Self::new()
431    }
432}
433
434/// Encryption directive for a single parameter, parsed from result set 2 of
435/// `sp_describe_parameter_encryption`.
436#[derive(Debug, Clone)]
437pub struct ParameterCryptoInfo {
438    /// 0-based index into [`ParameterEncryptionInfo::cek_table`].
439    ///
440    /// The server reports a (often 1-based) key ordinal; the parser translates
441    /// it to this positional index so `cek_table.get(cek_ordinal)` resolves the
442    /// entry directly.
443    pub cek_ordinal: u16,
444    /// Encryption type (deterministic or randomized).
445    pub encryption_type: EncryptionTypeWire,
446    /// Encryption algorithm ID (2 = AEAD_AES_256_CBC_HMAC_SHA256).
447    pub algorithm_id: u8,
448    /// Normalization rule version applied to the plaintext before encryption.
449    pub normalization_rule_version: u8,
450}
451
452impl ParameterCryptoInfo {
453    /// Create new parameter crypto info.
454    pub fn new(
455        cek_ordinal: u16,
456        encryption_type: EncryptionTypeWire,
457        algorithm_id: u8,
458        normalization_rule_version: u8,
459    ) -> Self {
460        Self {
461            cek_ordinal,
462            encryption_type,
463            algorithm_id,
464            normalization_rule_version,
465        }
466    }
467}
468
/// Parsing of the two result sets returned by `sp_describe_parameter_encryption`.
///
/// Result set 1 is the CEK table (one row per CMK-wrapping of each CEK); result
/// set 2 lists, per parameter, how the server expects it encrypted. The column
/// layout was captured against a live server (SQL Server 2016+): the first nine
/// RS1 columns are stable across versions; SQL Server 2019+ append two enclave
/// columns (`is_requested_by_enclave`, `column_master_key_signature`) which this
/// non-enclave path ignores. Columns are read positionally to match the
/// `Microsoft.Data.SqlClient` ordinals.
#[cfg(feature = "always-encrypted")]
impl ParameterEncryptionInfo {
    /// Minimum RS1 column count (SQL Server 2017 returns exactly this; 2019+
    /// return more, with the extra columns appended after these).
    const RS1_MIN_COLS: usize = 9;
    /// RS2 column count, stable across supported versions.
    const RS2_MIN_COLS: usize = 6;

    /// Parse `sp_describe_parameter_encryption` output into encryption metadata.
    ///
    /// `result_sets` must be the `ProcedureResult::result_sets` from that RPC.
    /// Plaintext parameters (encryption type 0) are omitted from the result.
    ///
    /// # Errors
    ///
    /// Returns `Error::Protocol` when the output does not have the expected
    /// shape (too few result sets or columns, unexpected column types, an
    /// unknown encryption type, or a dangling CEK ordinal). Row-reading errors
    /// are propagated.
    pub(crate) fn from_describe_result_sets(result_sets: &mut [ResultSet]) -> Result<Self, Error> {
        let available = result_sets.len();
        if available < 2 {
            return Err(Error::Protocol(format!(
                "sp_describe_parameter_encryption returned {available} result set(s), expected 2"
            )));
        }

        let (cek_table, ordinal_to_index) = Self::parse_cek_table(&mut result_sets[0])?;
        let parameters =
            Self::parse_parameter_directives(&mut result_sets[1], &ordinal_to_index)?;

        Ok(Self {
            cek_table,
            parameters,
        })
    }

    /// Parse result set 1 into the CEK table.
    ///
    /// Rows sharing a server key ordinal are further CMK-wrappings of the same
    /// CEK and are merged into one entry. Also returns the map from server key
    /// ordinal to that entry's position in the table.
    fn parse_cek_table(rs: &mut ResultSet) -> Result<(CekTable, HashMap<i32, u16>), Error> {
        let rs1_cols = rs.columns().len();
        if rs1_cols < Self::RS1_MIN_COLS {
            return Err(Error::Protocol(format!(
                "sp_describe_parameter_encryption result set 1 has {rs1_cols} columns, expected >= {}",
                Self::RS1_MIN_COLS
            )));
        }
        let rows = rs.collect_all()?;

        let mut entries: Vec<CekTableEntry> = Vec::new();
        let mut ordinal_to_index: HashMap<i32, u16> = HashMap::new();

        for row in &rows {
            let key_ordinal = describe_int(row, 0, "column_encryption_key_ordinal")?;
            // Columns 5-8: one CMK wrapping of this CEK.
            let wrapped = CekValue {
                encrypted_value: describe_varbinary(
                    row,
                    5,
                    "column_encryption_key_encrypted_value",
                )?,
                key_store_provider_name: describe_nvarchar(
                    row,
                    6,
                    "column_master_key_store_provider_name",
                )?,
                cmk_path: describe_nvarchar(row, 7, "column_master_key_path")?,
                encryption_algorithm: describe_nvarchar(
                    row,
                    8,
                    "column_encryption_key_encryption_algorithm_name",
                )?,
            };

            match ordinal_to_index.get(&key_ordinal).copied() {
                // Another CMK-wrapping of an already-seen CEK (key rotation).
                Some(idx) => entries[usize::from(idx)].values.push(wrapped),
                None => {
                    let idx = u16::try_from(entries.len()).map_err(|_| {
                        Error::Protocol(
                            "sp_describe_parameter_encryption returned too many CEKs".into(),
                        )
                    })?;
                    ordinal_to_index.insert(key_ordinal, idx);
                    entries.push(CekTableEntry {
                        database_id: describe_int(row, 1, "database_id")? as u32,
                        cek_id: describe_int(row, 2, "column_encryption_key_id")? as u32,
                        cek_version: describe_int(row, 3, "column_encryption_key_version")?
                            as u32,
                        cek_md_version: describe_md_version(row, 4)?,
                        values: vec![wrapped],
                    });
                }
            }
        }

        Ok((CekTable { entries }, ordinal_to_index))
    }

    /// Parse result set 2 into per-parameter directives keyed by parameter name.
    ///
    /// Parameters the server reports with encryption type 0 (no encryption
    /// needed) are skipped. Server key ordinals are translated to CEK table
    /// positions through `ordinal_to_index`.
    fn parse_parameter_directives(
        rs: &mut ResultSet,
        ordinal_to_index: &HashMap<i32, u16>,
    ) -> Result<HashMap<String, ParameterCryptoInfo>, Error> {
        let rs2_cols = rs.columns().len();
        if rs2_cols < Self::RS2_MIN_COLS {
            return Err(Error::Protocol(format!(
                "sp_describe_parameter_encryption result set 2 has {rs2_cols} columns, expected >= {}",
                Self::RS2_MIN_COLS
            )));
        }
        let rows = rs.collect_all()?;

        let mut parameters = HashMap::new();
        for row in &rows {
            let name = describe_nvarchar(row, 1, "parameter_name")?;
            let encryption_type_byte = describe_tinyint(row, 3, "column_encryption_type")?;
            if encryption_type_byte == 0 {
                // The server determined this parameter needs no encryption.
                continue;
            }
            let encryption_type = match EncryptionTypeWire::from_u8(encryption_type_byte) {
                Some(kind) => kind,
                None => {
                    return Err(Error::Protocol(format!(
                        "sp_describe_parameter_encryption: invalid column_encryption_type {encryption_type_byte} for {name}"
                    )));
                }
            };
            let algorithm_id = describe_tinyint(row, 2, "column_encryption_algorithm")?;
            let server_ordinal = describe_int(row, 4, "column_encryption_key_ordinal")?;
            let normalization_rule_version =
                describe_tinyint(row, 5, "column_encryption_normalization_rule_version")?;

            let cek_ordinal = match ordinal_to_index.get(&server_ordinal) {
                Some(&idx) => idx,
                None => {
                    return Err(Error::Protocol(format!(
                        "sp_describe_parameter_encryption: parameter {name} references CEK ordinal {server_ordinal} absent from the CEK table"
                    )));
                }
            };

            parameters.insert(
                name,
                ParameterCryptoInfo::new(
                    cek_ordinal,
                    encryption_type,
                    algorithm_id,
                    normalization_rule_version,
                ),
            );
        }

        Ok(parameters)
    }
}
606
/// Read column `idx` of a describe row as an `int` (`i32`).
///
/// Any other value, or an absent column, becomes a protocol error naming
/// `col` via [`describe_type_error`].
#[cfg(feature = "always-encrypted")]
fn describe_int(row: &Row, idx: usize, col: &str) -> Result<i32, Error> {
    let raw = row.get_raw(idx);
    if let Some(SqlValue::Int(v)) = raw {
        return Ok(v);
    }
    Err(describe_type_error(col, idx, "int", raw.as_ref()))
}
615
/// Read column `idx` of a describe row as a `tinyint` (`u8`).
///
/// Any other value, or an absent column, becomes a protocol error naming
/// `col` via [`describe_type_error`].
#[cfg(feature = "always-encrypted")]
fn describe_tinyint(row: &Row, idx: usize, col: &str) -> Result<u8, Error> {
    let raw = row.get_raw(idx);
    if let Some(SqlValue::TinyInt(v)) = raw {
        return Ok(v);
    }
    Err(describe_type_error(col, idx, "tinyint", raw.as_ref()))
}
624
/// Read column `idx` of a describe row as an `nvarchar` (`String`).
///
/// Any other value, or an absent column, becomes a protocol error naming
/// `col` via [`describe_type_error`].
#[cfg(feature = "always-encrypted")]
fn describe_nvarchar(row: &Row, idx: usize, col: &str) -> Result<String, Error> {
    let raw = row.get_raw(idx);
    if let Some(SqlValue::String(v)) = raw {
        return Ok(v);
    }
    Err(describe_type_error(col, idx, "nvarchar", raw.as_ref()))
}
633
/// Read column `idx` of a describe row as a `varbinary` (`bytes::Bytes`).
///
/// Any other value, or an absent column, becomes a protocol error naming
/// `col` via [`describe_type_error`].
#[cfg(feature = "always-encrypted")]
fn describe_varbinary(row: &Row, idx: usize, col: &str) -> Result<bytes::Bytes, Error> {
    let raw = row.get_raw(idx);
    if let Some(SqlValue::Binary(v)) = raw {
        return Ok(v);
    }
    Err(describe_type_error(col, idx, "varbinary", raw.as_ref()))
}
642
/// Read the `binary(8)` CEK metadata-version column and interpret its bytes as
/// a little-endian `u64`.
///
/// A binary value of any other length, a non-binary value, or an absent
/// column is reported via [`describe_type_error`].
#[cfg(feature = "always-encrypted")]
fn describe_md_version(row: &Row, idx: usize) -> Result<u64, Error> {
    let raw = row.get_raw(idx);
    let version = match &raw {
        Some(SqlValue::Binary(b)) => <[u8; 8]>::try_from(&b[..]).ok().map(u64::from_le_bytes),
        _ => None,
    };
    version.ok_or_else(|| {
        describe_type_error(
            "column_encryption_key_metadata_version",
            idx,
            "binary(8)",
            raw.as_ref(),
        )
    })
}
660
/// Build the protocol error used when a describe column does not hold the
/// expected type. `got` is the value actually read (`None` is reported as
/// "missing").
#[cfg(feature = "always-encrypted")]
fn describe_type_error(col: &str, idx: usize, expected: &str, got: Option<&SqlValue>) -> Error {
    let got = match got {
        Some(value) => value.type_name(),
        None => "missing",
    };
    let message = format!(
        "sp_describe_parameter_encryption column {col} (#{idx}): expected {expected}, got {got}"
    );
    Error::Protocol(message)
}
669
/// Normalize a parameter value to the plaintext byte form Always Encrypted
/// encrypts — SQL Server's "normalized" form for the value's type. The result
/// is the plaintext input to [`EncryptionContext::encrypt_value`].
///
/// Normalization is type-specific and is **not** the regular TDS wire encoding:
/// e.g. INT normalizes to 8 little-endian bytes (not 4), and strings/binaries
/// carry no length prefix. These layouts are validated byte-for-byte against
/// Microsoft.Data.SqlClient (see the `ae_normalization` tests). Only the types
/// supported so far are handled; others return `UnsupportedOperation`.
///
/// Typed temporal parameters (`time`/`datetime2`/`datetimeoffset`/`datetime`)
/// pass their [`mssql_types::EncryptedParamType`] in `param_type`: their byte
/// length depends on the column scale, so the value alone is insufficient.
///
/// # Errors
///
/// Returns [`EncryptionError::UnsupportedOperation`] when:
/// - the value's type has no normalization implemented yet;
/// - a `char` value contains characters Windows-1252 cannot represent (they
///   would otherwise be silently replaced, changing the stored text);
/// - a `date` value lies outside SQL Server's 0001-01-01..=9999-12-31 range
///   (its day count would otherwise be truncated into a different date);
/// - a `smalldatetime` value cannot be encoded.
///
/// Errors from the scale-dependent temporal helpers are returned as-is. The
/// CHAR and DATE error messages deliberately omit the value, since it is
/// plaintext bound for an encrypted column.
#[cfg(feature = "always-encrypted")]
pub fn normalize_for_encryption(
    value: &SqlValue,
    param_type: Option<mssql_types::EncryptedParamType>,
) -> Result<Vec<u8>, EncryptionError> {
    // CHAR: the value's bytes in the column code page (Windows-1252), unpadded.
    // NCHAR/BINARY reuse the String/Binary value arms below (UTF-16 / raw).
    if let (Some(mssql_types::EncryptedParamType::Char { .. }), SqlValue::String(s)) =
        (param_type, value)
    {
        // `encode` substitutes HTML numeric character references (`&#NNNN;`)
        // for unmappable characters and flags that in the third tuple element;
        // encrypting the substitute would silently store different text.
        let (encoded, _, had_unmappable) = encoding_rs::WINDOWS_1252.encode(s);
        if had_unmappable {
            return Err(EncryptionError::UnsupportedOperation(
                "CHAR parameter contains characters not representable in Windows-1252".into(),
            ));
        }
        return Ok(encoded.into_owned());
    }
    // Typed temporal parameters carry the column scale (the encrypted byte
    // length depends on it), so they're handled from the hint, not the value.
    #[cfg(feature = "chrono")]
    {
        use mssql_types::EncryptedParamType as E;
        match (param_type, value) {
            (Some(E::Time { scale }), SqlValue::Time(t)) => return normalize_ae_time(*t, scale),
            (Some(E::DateTime2 { scale }), SqlValue::DateTime(dt)) => {
                return normalize_ae_datetime2(*dt, scale);
            }
            (Some(E::DateTimeOffset { scale }), SqlValue::DateTimeOffset(dto)) => {
                return normalize_ae_datetimeoffset(*dto, scale);
            }
            (Some(E::DateTime), SqlValue::DateTime(dt)) => {
                let mut buf = bytes::BytesMut::with_capacity(8);
                mssql_types::encode::encode_datetime_legacy(*dt, &mut buf);
                return Ok(buf.to_vec());
            }
            _ => {}
        }
    }
    match value {
        // All integer types AND bit normalize to 8-byte little-endian (the value
        // widened to i64). Validated against .NET: tinyint/smallint are 8 bytes,
        // not their native 1/2 — a spec-reading would get this wrong.
        SqlValue::Bool(v) => Ok(i64::from(*v).to_le_bytes().to_vec()),
        SqlValue::TinyInt(v) => Ok(i64::from(*v).to_le_bytes().to_vec()),
        SqlValue::SmallInt(v) => Ok(i64::from(*v).to_le_bytes().to_vec()),
        SqlValue::Int(v) => Ok(i64::from(*v).to_le_bytes().to_vec()),
        SqlValue::BigInt(v) => Ok(v.to_le_bytes().to_vec()),
        // REAL/FLOAT: the IEEE-754 bits, little-endian (4 and 8 bytes).
        SqlValue::Float(v) => Ok(v.to_le_bytes().to_vec()),
        SqlValue::Double(v) => Ok(v.to_le_bytes().to_vec()),
        // NVARCHAR: UTF-16LE code units, no length prefix.
        SqlValue::String(s) => Ok(s.encode_utf16().flat_map(u16::to_le_bytes).collect()),
        // VARBINARY: the raw bytes, no length prefix.
        SqlValue::Binary(b) => Ok(b.to_vec()),
        // UNIQUEIDENTIFIER: SQL Server's 16-byte mixed-endian GUID order (first
        // three groups byte-reversed from the RFC layout, last 8 as-is).
        #[cfg(feature = "uuid")]
        SqlValue::Uuid(u) => {
            let b = u.as_bytes();
            Ok(vec![
                b[3], b[2], b[1], b[0], b[5], b[4], b[7], b[6], b[8], b[9], b[10], b[11], b[12],
                b[13], b[14], b[15],
            ])
        }
        // DATE: 3-byte little-endian count of days since 0001-01-01.
        #[cfg(feature = "chrono")]
        SqlValue::Date(d) => {
            use chrono::Datelike;
            // SQL Server `date` spans years 1..=9999. chrono accepts far wider
            // (including negative) years, whose day count would wrap or exceed
            // 3 bytes and be truncated into a different date.
            if !(1..=9999).contains(&d.year()) {
                return Err(EncryptionError::UnsupportedOperation(
                    "DATE value is outside the SQL Server range 0001-01-01 to 9999-12-31".into(),
                ));
            }
            Ok(ae_date_bytes(*d).to_vec())
        }
        // DECIMAL/NUMERIC: 1 sign byte (0 negative, 1 positive) + 16-byte
        // little-endian unscaled magnitude. Uses the value's own scale.
        #[cfg(feature = "decimal")]
        SqlValue::Decimal(d) => {
            let mut out = Vec::with_capacity(17);
            out.push(u8::from(!d.is_sign_negative()));
            out.extend_from_slice(&d.mantissa().unsigned_abs().to_le_bytes());
            Ok(out)
        }
        // MONEY and SMALLMONEY both normalize to the 8-byte MONEY form: the
        // value scaled by 10_000 as an i64, high 32 bits then low 32 bits.
        #[cfg(feature = "decimal")]
        SqlValue::Money(d) | SqlValue::SmallMoney(d) => {
            let cents = money_cents(d)?;
            let mut out = ((cents >> 32) as i32).to_le_bytes().to_vec();
            out.extend_from_slice(&(cents as u32).to_le_bytes());
            Ok(out)
        }
        // SMALLDATETIME: 2-byte days-since-1900 + 2-byte minutes-since-midnight.
        // Declared correctly by the `SmallDateTime` wrapper, so no scale hint.
        #[cfg(feature = "chrono")]
        SqlValue::SmallDateTime(dt) => {
            let mut buf = bytes::BytesMut::with_capacity(4);
            mssql_types::encode::encode_smalldatetime(*dt, &mut buf).map_err(|e| {
                EncryptionError::UnsupportedOperation(format!("SMALLDATETIME: {e}"))
            })?;
            Ok(buf.to_vec())
        }
        other => Err(EncryptionError::UnsupportedOperation(format!(
            "Always Encrypted parameter encryption is not yet implemented for {}",
            other.type_name()
        ))),
    }
}
784
/// The 3-byte date component shared by the AE normalized forms of `date`,
/// `datetime2`, and `datetimeoffset`: the number of days elapsed since
/// 0001-01-01, little-endian.
///
/// `num_days_from_ce` counts 0001-01-01 as day 1, hence the `- 1`.
///
/// NOTE(review): the date range is not validated here. A date before
/// 0001-01-01 wraps through the `as u32` cast, and only the low 24 bits are
/// kept. Callers presumably restrict input to SQL Server's
/// 0001-01-01..=9999-12-31 range upstream; confirm.
#[cfg(all(feature = "always-encrypted", feature = "chrono"))]
fn ae_date_bytes(d: chrono::NaiveDate) -> [u8; 3] {
    use chrono::Datelike;
    let elapsed = d.num_days_from_ce() - 1;
    let [lo, mid, hi, _] = (elapsed as u32).to_le_bytes();
    [lo, mid, hi]
}
794
/// The AE normalized form for `time(scale)`: a little-endian count of
/// `10^-scale`-second ticks since midnight, in 3/4/5 bytes for scale 0–2/3–4/5–7
/// (matching SQL Server's `time` storage). Sub-scale digits are rounded
/// half-up.
///
/// # Errors
///
/// Returns [`EncryptionError::UnsupportedOperation`] when `scale > 7`, or when
/// the value does not fit inside one day at the requested scale. That happens
/// when rounding carries it up to midnight (e.g. `23:59:59.6` at scale 0), or
/// when it is a chrono leap-second value past `23:59:59`. Emitting such a tick
/// count would encrypt a `time` value SQL Server cannot represent.
#[cfg(all(feature = "always-encrypted", feature = "chrono"))]
fn normalize_ae_time(t: chrono::NaiveTime, scale: u8) -> Result<Vec<u8>, EncryptionError> {
    use chrono::Timelike;
    if scale > 7 {
        return Err(EncryptionError::UnsupportedOperation(format!(
            "time scale {scale} out of range (0–7)"
        )));
    }
    // chrono represents a leap second as `nanosecond() >= 1_000_000_000`, so
    // this sum can already exceed one day before any rounding happens.
    let nanos =
        u64::from(t.num_seconds_from_midnight()) * 1_000_000_000 + u64::from(t.nanosecond());
    let divisor = 10u64.pow(9 - u32::from(scale));
    let ticks = (nanos + divisor / 2) / divisor;
    // One full day expressed in ticks at this scale; valid values lie strictly
    // below it. Rounding up from 23:59:59.9… would otherwise produce exactly
    // this value.
    let ticks_per_day = 86_400 * 10u64.pow(u32::from(scale));
    if ticks >= ticks_per_day {
        return Err(EncryptionError::UnsupportedOperation(format!(
            "time value {t} rounds to or past midnight at scale {scale}"
        )));
    }
    // With `ticks < ticks_per_day` the value always fits in `len` bytes
    // (8_639_999 < 2^24, 863_999_999 < 2^32, 863_999_999_999 < 2^40).
    let len = match scale {
        0..=2 => 3,
        3..=4 => 4,
        _ => 5,
    };
    Ok(ticks.to_le_bytes()[..len].to_vec())
}
817
/// AE normalized `datetime2(scale)`: the `time(scale)` tick bytes for the
/// time of day, immediately followed by the 3-byte day count of the date.
///
/// # Errors
///
/// Propagates any error from [`normalize_ae_time`] (e.g. a `scale` above 7).
#[cfg(all(feature = "always-encrypted", feature = "chrono"))]
fn normalize_ae_datetime2(
    dt: chrono::NaiveDateTime,
    scale: u8,
) -> Result<Vec<u8>, EncryptionError> {
    let date_part = ae_date_bytes(dt.date());
    normalize_ae_time(dt.time(), scale).map(|mut bytes| {
        bytes.extend_from_slice(&date_part);
        bytes
    })
}
829
/// AE normalized `datetimeoffset(scale)`: the UTC `time(scale)` ticks, the
/// 3-byte UTC date, then the offset in minutes as a 2-byte little-endian i16.
///
/// Any seconds component of the offset is truncated toward zero, because the
/// normalized form stores whole minutes only. The UTC instant itself is
/// preserved.
///
/// # Errors
///
/// Propagates errors from [`normalize_ae_time`]. Also returns
/// [`EncryptionError::UnsupportedOperation`] when the offset lies outside SQL
/// Server's `datetimeoffset` range of -14:00..=+14:00, which would otherwise
/// be encoded into a value the server and other clients cannot represent.
#[cfg(all(feature = "always-encrypted", feature = "chrono"))]
fn normalize_ae_datetimeoffset(
    dto: chrono::DateTime<chrono::FixedOffset>,
    scale: u8,
) -> Result<Vec<u8>, EncryptionError> {
    use chrono::Offset;

    /// Largest offset magnitude SQL Server accepts (14 hours), in minutes.
    const MAX_OFFSET_MINUTES: i32 = 14 * 60;

    let utc = dto.naive_utc();
    let mut out = normalize_ae_time(utc.time(), scale)?;
    out.extend_from_slice(&ae_date_bytes(utc.date()));

    let offset_minutes = dto.offset().fix().local_minus_utc() / 60;
    if offset_minutes.abs() > MAX_OFFSET_MINUTES {
        return Err(EncryptionError::UnsupportedOperation(format!(
            "datetimeoffset offset {} out of range (-14:00..=+14:00)",
            dto.offset()
        )));
    }
    // Lossless: |offset_minutes| <= 840 was checked above.
    out.extend_from_slice(&(offset_minutes as i16).to_le_bytes());
    Ok(out)
}
845
/// Converts a decimal to the MONEY fixed-point representation
/// (`value * 10_000`) as an `i64`. Digits beyond four decimal places are
/// dropped, i.e. the value is truncated toward zero. Shared by the MONEY and
/// SMALLMONEY normalizations.
///
/// # Errors
///
/// [`EncryptionError::UnsupportedOperation`] if the scaled value does not fit
/// in an `i64`.
#[cfg(all(feature = "always-encrypted", feature = "decimal"))]
fn money_cents(value: &rust_decimal::Decimal) -> Result<i64, EncryptionError> {
    let out_of_range =
        || EncryptionError::UnsupportedOperation("MONEY value out of range".into());
    let unscaled: i128 = value.mantissa();
    let scaled = match value.scale() {
        // Fewer than four fractional digits: pad up to the MONEY scale.
        s if s <= 4 => unscaled
            .checked_mul(10_i128.pow(4 - s))
            .ok_or_else(out_of_range)?,
        // More than four: integer division truncates the excess toward zero.
        s => unscaled / 10_i128.pow(s - 4),
    };
    i64::try_from(scaled).map_err(|_| out_of_range())
}
864
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Decodes an even-length ASCII hex string into bytes. Test-only: panics
    /// on malformed input.
    #[cfg(feature = "always-encrypted")]
    fn unhex(s: &str) -> Vec<u8> {
        (0..s.len())
            .step_by(2)
            .map(|i| u8::from_str_radix(&s[i..i + 2], 16).unwrap())
            .collect()
    }

    /// Normalizes each value (no type hint), encrypts it deterministically
    /// under the hex-encoded CEK `cek_hex`, and asserts that the ciphertext
    /// equals the hex `reference` captured from Microsoft.Data.SqlClient.
    #[cfg(feature = "always-encrypted")]
    fn assert_ciphertexts_match_dotnet(cek_hex: &str, cases: &[(SqlValue, &str)]) {
        let cek = unhex(cek_hex);
        let enc = AeadEncryptor::new(&cek).unwrap();
        for (value, reference) in cases {
            let norm = normalize_for_encryption(value, None).unwrap();
            let cipher = enc
                .encrypt(&norm, mssql_auth::EncryptionType::Deterministic)
                .unwrap();
            assert_eq!(
                cipher,
                unhex(reference),
                "ciphertext for {} must match Microsoft.Data.SqlClient",
                value.type_name()
            );
        }
    }

    /// Reference ciphertexts captured from a live deterministic Always Encrypted
    /// INSERT via Microsoft.Data.SqlClient 5.2.2. Encrypting our normalization
    /// with the same CEK must reproduce them byte-for-byte — proving the
    /// normalized layout matches the real .NET client (notably INT -> 8 LE bytes,
    /// which is the layout a naive implementation would get wrong).
    #[cfg(feature = "always-encrypted")]
    #[test]
    fn ae_normalization_matches_dotnet() {
        use bytes::Bytes;

        assert_ciphertexts_match_dotnet(
            "B59D9F2C96784C232D53AB273D257DC79B7D2355BB82B1EC7054CE25E25F7B44",
            &[
                (
                    SqlValue::Int(42),
                    "01102FC5DEC5D3E463A8F4BDF512AA74E6AB953BA9A2F3F9A98CD18446B007DE5A6E2A1D1EB775035EA189CA5160A935CE093CAA9BB7E9233BB333AADEE86FDE1D",
                ),
                (
                    SqlValue::String("Ada".to_string()),
                    "01BFAC40E6DA541ACEFAD8ECF5598DB77B0C5349CFACBC3C9221C01B6037E593B78E8F398F620F837BD6A4A2B644125C4188DF278B94479B2218466D91107FE417",
                ),
                (
                    SqlValue::Binary(Bytes::from_static(&[0x01, 0x02, 0x03])),
                    "01ADE71457495F00FC9A16456F1B1EECB901D88DE97887025C189B1C4432E02071AB7594C48518CA5621E90165FAE337475B4CF3A3D00EF2D862FB0473713DF1E1",
                ),
            ],
        );
    }

    /// `normalize_for_encryption` rejects values it has no normalization for
    /// rather than silently producing wrong bytes. NULL is never normalized
    /// (it is handled as a NULL parameter upstream), so it exercises the
    /// catch-all rejection arm and stays unsupported as more types are added.
    #[cfg(feature = "always-encrypted")]
    #[test]
    fn ae_normalization_rejects_unnormalizable_value() {
        assert!(normalize_for_encryption(&SqlValue::Null, None).is_err());
    }

    /// Numeric-scalar normalization, validated byte-for-byte against
    /// Microsoft.Data.SqlClient (same method as [`ae_normalization_matches_dotnet`],
    /// captured with a fresh CEK). This is the interop guarantee: a value the
    /// driver encrypts is the value .NET would encrypt. Notable: every integer
    /// width and bit normalize to 8 bytes, real to 4, float to 8.
    #[cfg(feature = "always-encrypted")]
    #[test]
    fn ae_normalization_matches_dotnet_numeric() {
        assert_ciphertexts_match_dotnet(
            "9590E42A8A6C8F13B5D09B8D5A128EF8B3A4A10301C7AF24AFC62ED0E02342F7",
            &[
                (
                    SqlValue::BigInt(0x0102030405060708),
                    "01E765FC4696660028BFD48FCAEAED81E0EB423CFF433CA97F1B2FF02F70744E7265C2AE73CAA562FFA98AF98CB1D3EF6A4649B3640359E1DB7D170C80E639DA68",
                ),
                (
                    SqlValue::SmallInt(258),
                    "012545AB817E1AEBDCEE1C00AEBFF3A013CAD20E0377BEFDD9186C263F8D1A909C313A753996F1B5E4A4AE17E901F6F781DCA707544766995D339601CA414063A0",
                ),
                (
                    SqlValue::TinyInt(200),
                    "01A97C33480277D16FFAEDA9068173D4173378542F2887EBCD31CDEEEB116BD59D48F9D459BDDCABAE469E891B4F82AA3D283440CA1B5E9FFC150F9D0AE54EC21E",
                ),
                (
                    SqlValue::Bool(true),
                    "01DDE18564051D630EE026331BCCAFC8F4122CC3919F81459F37D9C0E0C64A5317FCA08660FE5FC855917B97B72013F25B85ADD14ADDD7D5ED022EB1297FF29A7E",
                ),
                (
                    SqlValue::Float(3.5),
                    "017A452760E7BA7AA6A716F6707F55D9C3A81683C04A6B561B13AC1D8A848E93E239BB922EE3EE628B6D0081A590BB11747CC25D216240FB10171A0FA3B99A2DB3",
                ),
                (
                    SqlValue::Double(3.5),
                    "0171611557351FBC4561EBF0B9C98E0DC38AD2BD3E2C1D1E82F185D7E67D0425E506D11DD67BA3EB38F34FB01A8FCEF7E4B9A7256944334A521526613CFF6C8C5F",
                ),
            ],
        );
    }

    /// UUID and DATE normalization, validated byte-for-byte against
    /// Microsoft.Data.SqlClient: uuid uses SQL Server's mixed-endian GUID byte
    /// order, date is a 3-byte little-endian day count since 0001-01-01.
    #[cfg(all(feature = "always-encrypted", feature = "uuid", feature = "chrono"))]
    #[test]
    fn ae_normalization_matches_dotnet_uuid_date() {
        assert_ciphertexts_match_dotnet(
            "9590E42A8A6C8F13B5D09B8D5A128EF8B3A4A10301C7AF24AFC62ED0E02342F7",
            &[
                (
                    SqlValue::Uuid(
                        uuid::Uuid::parse_str("01020304-0506-0708-090a-0b0c0d0e0f10").unwrap(),
                    ),
                    "01F58635AA18692D68BDF551ECDD7AC3A56682D3F91F111F8D8F36D5425C405A8F6AB3ED3C3666444478476BD65FF40DC83F6831F502826AFEEC3116F71A7A2020CCD254F4BA28FCDC0F96BA2E5264AE9E",
                ),
                (
                    SqlValue::Date(chrono::NaiveDate::from_ymd_opt(2024, 3, 15).unwrap()),
                    "0188B4F75A1F4BDA53C9CDDC1918C09CB57F68E13F5560F1F1D7168FE70707337B1156A97915B244F3C03D3E7352882A599511BD243471FD03683F371CF44E4B76",
                ),
            ],
        );
    }

    /// DECIMAL and MONEY/SMALLMONEY normalization, validated byte-for-byte
    /// against Microsoft.Data.SqlClient: decimal is a sign byte plus a 16-byte
    /// little-endian unscaled magnitude; money and smallmoney both use the
    /// 8-byte MONEY form (value × 10_000, high then low 32 bits).
    #[cfg(all(feature = "always-encrypted", feature = "decimal"))]
    #[test]
    fn ae_normalization_matches_dotnet_decimal_money() {
        let dec = rust_decimal::Decimal::new(123_456_789, 4); // 12345.6789
        let money = rust_decimal::Decimal::new(123_400, 4); // 12.3400

        assert_ciphertexts_match_dotnet(
            "CBFB5AE21FB517C65DA0C6E8E11969C630798E473EF5827A70398012DF1D4B9E",
            &[
                (
                    SqlValue::Decimal(dec),
                    "018FAE46024B9B406C23600E6A9C694F9A9B39B785A995689EBE19437BA7E75768011A035A5B54B5E495512EBB46AE1146130940A0D0D834D61AA89B5AD9F71FFAF6EEEAE77E4856BA2AA5E016E2950A8D",
                ),
                (
                    SqlValue::Money(money),
                    "01B4CE4CAD8D6B241A1555C377A0ADD4C79424DD5162F710D116594F725C1BAB015169A0C7716076EEC90E013519B961DEF427BFC32462D9E45D166C791B73F793",
                ),
                (
                    SqlValue::SmallMoney(money),
                    "01B4CE4CAD8D6B241A1555C377A0ADD4C79424DD5162F710D116594F725C1BAB015169A0C7716076EEC90E013519B961DEF427BFC32462D9E45D166C791B73F793",
                ),
            ],
        );
    }

    /// Temporal AE normalization, validated byte-for-byte against the forms
    /// `Microsoft.Data.SqlClient` produces (decrypted from its ciphertext).
    /// Comparing the normalized plaintext is equivalent to comparing ciphertext
    /// because deterministic encryption is a pure function of the CEK and the
    /// plaintext. Only scale 7 is pinned here; no unit test in this module
    /// covers the lower-scale (3- and 4-byte) time layouts.
    #[cfg(all(feature = "always-encrypted", feature = "chrono"))]
    #[test]
    fn ae_normalization_matches_dotnet_temporal() {
        use mssql_types::EncryptedParamType as E;

        let day = chrono::NaiveDate::from_ymd_opt(2024, 3, 15).unwrap();
        let dt = day.and_hms_nano_opt(13, 14, 15, 123_456_700).unwrap();

        // time(7)
        assert_eq!(
            normalize_for_encryption(&SqlValue::Time(dt.time()), Some(E::Time { scale: 7 }))
                .unwrap(),
            unhex("07c4aaf46e"),
        );
        // datetime2(7)
        assert_eq!(
            normalize_for_encryption(&SqlValue::DateTime(dt), Some(E::DateTime2 { scale: 7 }))
                .unwrap(),
            unhex("07c4aaf46e8f460b"),
        );
        // datetimeoffset(7) +05:30 — normalized as UTC time + UTC date + offset minutes
        let dto = {
            use chrono::TimeZone;
            chrono::FixedOffset::east_opt(5 * 3600 + 30 * 60)
                .unwrap()
                .from_local_datetime(&dt)
                .single()
                .unwrap()
        };
        assert_eq!(
            normalize_for_encryption(
                &SqlValue::DateTimeOffset(dto),
                Some(E::DateTimeOffset { scale: 7 })
            )
            .unwrap(),
            unhex("0788f2da408f460b4a01"),
        );
        // legacy datetime
        let dt_legacy = day.and_hms_milli_opt(13, 14, 15, 123).unwrap();
        assert_eq!(
            normalize_for_encryption(&SqlValue::DateTime(dt_legacy), Some(E::DateTime)).unwrap(),
            unhex("34b10000d925da00"),
        );
        // smalldatetime (no scale hint — declared by the SmallDateTime wrapper)
        let sdt = day.and_hms_opt(13, 14, 0).unwrap();
        assert_eq!(
            normalize_for_encryption(&SqlValue::SmallDateTime(sdt), None).unwrap(),
            unhex("34b11a03"),
        );
    }

    /// Fixed-width char/nchar/binary AE normalization, validated byte-for-byte
    /// against Microsoft.Data.SqlClient. KEY FACT: the normalized form is the
    /// value's bytes, NOT padded to the declared width — char in the column code
    /// page (Windows-1252), nchar as UTF-16LE, binary raw.
    #[cfg(feature = "always-encrypted")]
    #[test]
    fn ae_normalization_matches_dotnet_fixed_width() {
        use mssql_types::EncryptedParamType as E;

        // char(10) "Hello" → Windows-1252 "Hello" (5 bytes, unpadded)
        assert_eq!(
            normalize_for_encryption(
                &SqlValue::String("Hello".to_string()),
                Some(E::Char { length: 10 })
            )
            .unwrap(),
            unhex("48656c6c6f"),
        );
        // nchar(10) "Hello" → UTF-16LE (10 bytes, unpadded)
        assert_eq!(
            normalize_for_encryption(
                &SqlValue::String("Hello".to_string()),
                Some(E::NChar { length: 10 })
            )
            .unwrap(),
            unhex("480065006c006c006f00"),
        );
        // binary(10) [1,2,3,4,5] → raw (5 bytes, unpadded)
        assert_eq!(
            normalize_for_encryption(
                &SqlValue::Binary(bytes::Bytes::from_static(&[1, 2, 3, 4, 5])),
                Some(E::Binary { length: 10 })
            )
            .unwrap(),
            unhex("0102030405"),
        );
    }

    #[test]
    fn test_encryption_config_defaults() {
        let config = EncryptionConfig::new();
        assert!(config.enabled);
        assert!(config.cache_ceks);
        assert!(!config.is_ready()); // No providers
    }

    #[test]
    fn test_result_set_encryption_info() {
        let cek_table = CekTable::new();
        let mut info = ResultSetEncryptionInfo::new(cek_table, 3);

        assert!(!info.is_column_encrypted(0));
        assert!(!info.is_column_encrypted(1));
        assert!(!info.is_column_encrypted(2));

        let metadata = CryptoMetadata {
            cek_table_ordinal: 0,
            base_user_type: 0,
            base_col_type: 0x26,
            base_type_info: tds_protocol::token::TypeInfo::default(),
            algorithm_id: 2,
            encryption_type: EncryptionTypeWire::Deterministic,
            normalization_version: 1,
        };

        info.set_column_crypto(1, metadata);
        assert!(!info.is_column_encrypted(0));
        assert!(info.is_column_encrypted(1));
        assert!(!info.is_column_encrypted(2));

        assert_eq!(
            info.get_encryption_type(1),
            Some(EncryptionTypeWire::Deterministic)
        );
    }

    #[test]
    fn test_parameter_encryption_info() {
        let mut info = ParameterEncryptionInfo::new();

        assert!(!info.needs_encryption("@p1"));

        let crypto = ParameterCryptoInfo::new(0, EncryptionTypeWire::Randomized, 2, 1);
        info.add_parameter("@p1".to_string(), crypto);

        assert!(info.needs_encryption("@p1"));
        assert!(!info.needs_encryption("@p2"));

        let param = info.get_parameter("@p1").unwrap();
        assert_eq!(param.encryption_type, EncryptionTypeWire::Randomized);
    }

    /// Parse synthetic `sp_describe_parameter_encryption` result sets that mirror
    /// the live wire shape (captured in `.tmp/ae-3a2-describe-schema.md`). The
    /// column *order* is validated separately by the live test; this exercises
    /// the logic the live single-CEK/single-CMK case cannot: grouping multiple
    /// CMK-wrappings under one CEK, translating the server's (1-based) key
    /// ordinal to a positional index, little-endian `binary(8)` md-version
    /// decode, and skipping plaintext parameters.
    #[cfg(feature = "always-encrypted")]
    #[test]
    fn parse_describe_result_sets_groups_ceks_and_skips_plaintext() {
        use crate::row::{Column, Row};
        use crate::stream::ResultSet;
        use bytes::Bytes;

        fn rs(n_cols: usize, rows: Vec<Vec<SqlValue>>) -> ResultSet {
            let cols: Vec<Column> = (0..n_cols)
                .map(|i| Column::new(format!("c{i}"), i, "x"))
                .collect();
            let rows = rows
                .into_iter()
                .map(|vals| Row::from_values(cols.clone(), vals))
                .collect();
            ResultSet::new(cols, rows)
        }

        let mdv1 = Bytes::from_static(&[1, 0, 0, 0, 0, 0, 0, 0]); // -> 1
        let mdv2 = Bytes::from_static(&[255, 0, 0, 0, 0, 0, 0, 0]); // -> 255

        // RS1: CEK ordinal 1 wrapped by two CMKs (rotation), plus CEK ordinal 2.
        let rs1 = rs(
            9,
            vec![
                vec![
                    SqlValue::Int(1),
                    SqlValue::Int(7),
                    SqlValue::Int(56),
                    SqlValue::Int(1),
                    SqlValue::Binary(mdv1.clone()),
                    SqlValue::Binary(Bytes::from_static(b"env-a")),
                    SqlValue::String("IN_MEMORY_KEY_STORE".into()),
                    SqlValue::String("path-a".into()),
                    SqlValue::String("RSA_OAEP".into()),
                ],
                vec![
                    SqlValue::Int(1),
                    SqlValue::Int(7),
                    SqlValue::Int(56),
                    SqlValue::Int(1),
                    SqlValue::Binary(mdv1),
                    SqlValue::Binary(Bytes::from_static(b"env-a2")),
                    SqlValue::String("PROV_2".into()),
                    SqlValue::String("path-a2".into()),
                    SqlValue::String("RSA_OAEP".into()),
                ],
                vec![
                    SqlValue::Int(2),
                    SqlValue::Int(7),
                    SqlValue::Int(57),
                    SqlValue::Int(1),
                    SqlValue::Binary(mdv2),
                    SqlValue::Binary(Bytes::from_static(b"env-b")),
                    SqlValue::String("IN_MEMORY_KEY_STORE".into()),
                    SqlValue::String("path-b".into()),
                    SqlValue::String("RSA_OAEP".into()),
                ],
            ],
        );

        // RS2: @det on CEK ordinal 1, @rand on CEK ordinal 2, @plain plaintext.
        let rs2 = rs(
            6,
            vec![
                vec![
                    SqlValue::Int(1),
                    SqlValue::String("@det".into()),
                    SqlValue::TinyInt(2),
                    SqlValue::TinyInt(1),
                    SqlValue::Int(1),
                    SqlValue::TinyInt(1),
                ],
                vec![
                    SqlValue::Int(2),
                    SqlValue::String("@rand".into()),
                    SqlValue::TinyInt(2),
                    SqlValue::TinyInt(2),
                    SqlValue::Int(2),
                    SqlValue::TinyInt(1),
                ],
                vec![
                    SqlValue::Int(3),
                    SqlValue::String("@plain".into()),
                    SqlValue::TinyInt(0),
                    SqlValue::TinyInt(0),
                    SqlValue::Int(0),
                    SqlValue::TinyInt(0),
                ],
            ],
        );

        let mut sets = vec![rs1, rs2];
        let info = ParameterEncryptionInfo::from_describe_result_sets(&mut sets).unwrap();

        assert_eq!(info.cek_table.len(), 2);
        let e0 = info.cek_table.get(0).unwrap();
        assert_eq!(e0.cek_id, 56);
        assert_eq!(e0.cek_md_version, 1);
        assert_eq!(e0.values.len(), 2, "two CMK-wrappings group under one CEK");
        assert_eq!(e0.values[0].key_store_provider_name, "IN_MEMORY_KEY_STORE");
        assert_eq!(e0.values[1].key_store_provider_name, "PROV_2");
        let e1 = info.cek_table.get(1).unwrap();
        assert_eq!(e1.cek_id, 57);
        assert_eq!(e1.cek_md_version, 255);

        let det = info.get_parameter("@det").unwrap();
        assert_eq!(det.encryption_type, EncryptionTypeWire::Deterministic);
        assert_eq!(det.algorithm_id, 2);
        assert_eq!(det.normalization_rule_version, 1);
        assert_eq!(det.cek_ordinal, 0, "server ordinal 1 -> positional index 0");

        let rand = info.get_parameter("@rand").unwrap();
        assert_eq!(rand.encryption_type, EncryptionTypeWire::Randomized);
        assert_eq!(
            rand.cek_ordinal, 1,
            "server ordinal 2 -> positional index 1"
        );

        assert!(!info.needs_encryption("@plain"));
        assert_eq!(info.parameters.len(), 2);
    }

    /// A truncated response (fewer than two result sets) must be rejected, not
    /// silently treated as "no parameters need encryption".
    #[cfg(feature = "always-encrypted")]
    #[test]
    fn parse_describe_result_sets_rejects_missing_result_set() {
        use crate::row::{Column, Row};
        use crate::stream::ResultSet;

        let cols: Vec<Column> = (0..9)
            .map(|i| Column::new(format!("c{i}"), i, "x"))
            .collect();
        let mut sets = vec![ResultSet::new(cols, Vec::<Row>::new())];
        assert!(ParameterEncryptionInfo::from_describe_result_sets(&mut sets).is_err());
    }
}