Skip to main content

mssql_client/
encryption.rs

1//! Always Encrypted client-side encryption and decryption.
2//!
3//! This module provides the infrastructure for SQL Server's Always Encrypted feature,
4//! which enables client-side encryption of sensitive database columns.
5//!
6//! ## Architecture
7//!
8//! Always Encrypted uses a two-tier key hierarchy:
9//!
10//! ```text
11//! Column Master Key (CMK) - External (KeyVault, CertStore, HSM)
12//!         │
13//!         ▼ RSA-OAEP unwrap
14//! Column Encryption Key (CEK) - Stored encrypted in database
15//!         │
16//!         ▼ AEAD_AES_256_CBC_HMAC_SHA256
17//! Encrypted Column Data
18//! ```
19//!
20//! ## Usage
21//!
22//! ```rust,no_run
23//! # async fn with_always_encrypted() -> Result<(), Box<dyn std::error::Error>> {
24//! # #[cfg(feature = "always-encrypted")]
25//! # {
26//! # let conn_str = "Server=localhost;Database=db;Encrypt=strict;Column Encryption Setting=Enabled";
27//! use mssql_client::{Client, Config, EncryptionConfig};
28//! use mssql_auth::InMemoryKeyStore;
29//!
30//! // Register a key-store provider, then attach it to the connection config.
31//! let key_store = InMemoryKeyStore::new();
32//! let encryption_config = EncryptionConfig::new().with_provider(key_store);
33//!
34//! let config = Config::from_connection_string(conn_str)?
35//!     .with_column_encryption(encryption_config);
36//!
37//! let _client = Client::connect(config).await?;
38//! # }
39//! # Ok(())
40//! # }
41//! ```
42//!
43//! Equivalently, set `Column Encryption Setting=Enabled` in the connection
44//! string. Production-ready providers ship in `mssql-auth`: `InMemoryKeyStore`
45//! (dev/test), `AzureKeyVaultProvider` (`azure-identity` feature), and
46//! `WindowsCertStoreProvider` (`sspi-auth`, Windows). Implement
47//! [`mssql_auth::KeyStoreProvider`] for custom key storage. Do **not** substitute
48//! T-SQL `ENCRYPTBYKEY` — the server can see that plaintext, defeating the point.
49//!
50//! ## How decryption works
51//!
52//! 1. Always Encrypted support is negotiated in LOGIN7 (`FEATURE_EXT`).
53//! 2. `ColMetaData` carries [`CryptoMetadata`](tds_protocol::crypto::CryptoMetadata) and the [`CekTable`]; column
54//!    encryption keys are resolved asynchronously up front (calling the key-store
55//!    providers).
56//! 3. Each encrypted cell is decrypted during row parsing via
57//!    AEAD_AES_256_CBC_HMAC_SHA256, with the HMAC verified before decryption.
58//!
59//! Reads are transparent across `query`, `call_procedure`, the procedure
60//! builder, and multi-result queries. Parameter (write) encryption is wired
61//! into parameterized `query`/`execute` for the common scalar types — `int`,
62//! `tinyint`, `smallint`, `bigint`, `bit`, `real`, `float`, `nvarchar`,
63//! `varbinary`, `uniqueidentifier`, `date`, `money`, `smallmoney`, `decimal`
64//! (via `numeric(value, precision, scale)`), and typed `NULL` (via
65//! `null::<T>()`): with `Column Encryption Setting=Enabled` the
66//! driver describes the parameters (`sp_describe_parameter_encryption`),
67//! encrypts those bound to encrypted columns client-side, and sends them as
68//! encrypted RPC parameters (deterministic and randomized). The temporal and
69//! fixed-width types are supported through typed-parameter wrappers —
70//! `time(v, scale)`, `datetime2(v, scale)`, `datetimeoffset(v, scale)`,
71//! `datetime(v)` (legacy), the `SmallDateTime` wrapper, and `char(v, len)` /
72//! `nchar(v, len)` / `binary(v, len)`. The full scalar/temporal/fixed-width set
73//! is now covered.
74//!
75//! Bind a `decimal` parameter with `numeric(value, precision, scale)`, not a
76//! plain `Decimal`, and a scaled temporal or fixed-width value with its wrapper,
77//! not a bare `NaiveDateTime`/`NaiveTime`/`String`: an encrypted column requires
78//! the declared type — precision/scale/length included — to match the column
79//! exactly, which a bare value can't convey, so the server rejects it with
80//! `Operand type clash` (Msg 206) at the describe step. Encrypted `decimal` is
81//! bounded to `scale ≤ 28` — `rust_decimal` cannot represent more fractional
82//! digits, so `numeric()` rejects a declared `scale > 28` (also `precision`
83//! outside `1..=38` and `scale > precision`) rather than emit a magnitude a
84//! Microsoft client reads at the wrong scale; this is decimal AE support
85//! bounded to `scale ≤ 28`, not full `decimal(38, s)` coverage. Encrypted
86//! `char`/`nchar` columns must use a `*_BIN2` collation (a SQL Server
87//! requirement for deterministic encryption of character types); `char` is
88//! encoded as Windows-1252, and a `char` value containing a character absent
89//! from Windows-1252 (e.g. `中`) is rejected rather than silently substituted
90//! (use `nchar` for non-Latin text). Temporal values are normalized to Always Encrypted's
91//! fixed-width form (time = 5 bytes, datetime2 = 8, datetimeoffset = 10, the
92//! value truncated to the column scale but stored at scale-7 width), matching
93//! `Microsoft.Data.SqlClient` and validated byte-for-byte against it at both
94//! scale 7 and scale 3.
95//!
96//! ## Security Model
97//!
//! - **Client-only decryption**: SQL Server never sees the plaintext of encrypted columns
//! - **DBA protection**: Database administrators without access to the CMK cannot read encrypted data
//! - **Key separation**: The CMK stays in its key store and is never sent to SQL Server
101
102#[cfg(feature = "always-encrypted")]
103use std::collections::HashMap;
104
105use mssql_auth::KeyStoreProvider;
106#[cfg(feature = "always-encrypted")]
107use tds_protocol::crypto::{CekTable, CekTableEntry, EncryptionTypeWire};
108
109#[cfg(feature = "always-encrypted")]
110use mssql_auth::{AeadEncryptor, CekCache, CekCacheKey, EncryptionError};
111#[cfg(feature = "always-encrypted")]
112use mssql_types::SqlValue;
113#[cfg(feature = "always-encrypted")]
114use std::sync::Arc;
115
116#[cfg(feature = "always-encrypted")]
117use crate::{Error, row::Row, stream::ResultSet};
118#[cfg(feature = "always-encrypted")]
119use tds_protocol::crypto::CekValue;
120
121/// Configuration for Always Encrypted feature.
122#[derive(Default)]
123pub struct EncryptionConfig {
124    /// Whether encryption is enabled.
125    pub enabled: bool,
126    /// Registered key store providers.
127    providers: Vec<Box<dyn KeyStoreProvider>>,
128    /// Whether to cache decrypted CEKs for performance.
129    pub cache_ceks: bool,
130}
131
132impl EncryptionConfig {
133    /// Create a new encryption configuration (disabled by default).
134    #[must_use]
135    pub fn new() -> Self {
136        Self {
137            enabled: true,
138            providers: Vec::new(),
139            cache_ceks: true,
140        }
141    }
142
143    /// Register a key store provider.
144    pub fn register_provider(&mut self, provider: impl KeyStoreProvider + 'static) {
145        self.providers.push(Box::new(provider));
146    }
147
148    /// Builder method to add a key store provider.
149    #[must_use]
150    pub fn with_provider(mut self, provider: impl KeyStoreProvider + 'static) -> Self {
151        self.register_provider(provider);
152        self
153    }
154
155    /// Enable or disable CEK caching.
156    #[must_use]
157    pub fn with_cek_caching(mut self, enabled: bool) -> Self {
158        self.cache_ceks = enabled;
159        self
160    }
161
162    /// Get a provider by name.
163    pub fn get_provider(&self, name: &str) -> Option<&dyn KeyStoreProvider> {
164        self.providers
165            .iter()
166            .find(|p| p.provider_name() == name)
167            .map(|p| p.as_ref())
168    }
169
170    /// Check if encryption is ready (enabled and has providers).
171    #[must_use]
172    pub fn is_ready(&self) -> bool {
173        self.enabled && !self.providers.is_empty()
174    }
175}
176
177impl std::fmt::Debug for EncryptionConfig {
178    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
179        f.debug_struct("EncryptionConfig")
180            .field("enabled", &self.enabled)
181            .field("provider_count", &self.providers.len())
182            .field("cache_ceks", &self.cache_ceks)
183            .finish()
184    }
185}
186
/// Live Always Encrypted state for a connected client.
///
/// Holds the user's [`EncryptionConfig`] (used to look up key store providers
/// by name) together with a cache of encryptors built from decrypted column
/// encryption keys.
///
/// The configuration sits behind an `Arc` because the outer `Config`, and with
/// it the encryption config handle, may be cloned several times during
/// connection retries and redirects. Every clone keeps access to the same
/// registered providers.
#[cfg(feature = "always-encrypted")]
pub(crate) struct EncryptionContext {
    /// Shared configuration; providers are resolved by name through it, no
    /// matter how many `Arc` clones exist.
    config: Arc<EncryptionConfig>,
    /// Encryptors built from decrypted CEKs.
    cek_cache: CekCache,
    /// Whether `cek_cache` is consulted and populated.
    cache_enabled: bool,
}
206
#[cfg(feature = "always-encrypted")]
impl EncryptionContext {
    /// Create a new encryption context from an Arc-wrapped configuration.
    ///
    /// The Arc is retained by the context so provider lookups continue to
    /// work for the lifetime of the client, regardless of how many times
    /// the outer `Config` has been cloned for retry/redirect handling. The
    /// caching flag is copied from `config.cache_ceks` here.
    pub fn from_arc(config: std::sync::Arc<EncryptionConfig>) -> Self {
        let cache_enabled = config.cache_ceks;
        Self {
            config,
            cek_cache: CekCache::new(),
            cache_enabled,
        }
    }

    /// Get (or build) the encryptor for a column encryption key.
    ///
    /// 1. If caching is enabled and an encryptor for this CEK (database id,
    ///    key id, key version) is cached, return it.
    /// 2. Otherwise unwrap the CEK with a registered key store provider. The
    ///    primary wrapping is tried first. If it fails, the entry's other
    ///    wrappings are tried in order; during CMK rotation the same CEK is
    ///    wrapped by more than one CMK.
    /// 3. Build the encryptor from the unwrapped key, caching it when enabled.
    ///
    /// # Errors
    ///
    /// - `CekDecryptionFailed` if the entry has no CEK value at all.
    /// - If no wrapping can be unwrapped, the error from the primary
    ///   wrapping's attempt: `KeyStoreNotFound` when its provider is not
    ///   registered, otherwise the provider's decryption error.
    /// - Any error from constructing or caching the encryptor.
    pub(crate) async fn get_encryptor(
        &self,
        cek_entry: &CekTableEntry,
    ) -> Result<Arc<AeadEncryptor>, EncryptionError> {
        let cache_key = CekCacheKey::new(
            cek_entry.database_id,
            cek_entry.cek_id,
            cek_entry.cek_version,
        );

        // Check cache first
        if self.cache_enabled {
            if let Some(encryptor) = self.cek_cache.get(&cache_key) {
                return Ok(encryptor);
            }
        }

        let primary = cek_entry
            .primary_value()
            .ok_or_else(|| EncryptionError::CekDecryptionFailed("No CEK value available".into()))?;

        // Primary wrapping first, then every other wrapping of the same CEK,
        // so a client holding a provider for only one of the CMKs can still
        // unwrap the key.
        let candidates = std::iter::once(primary).chain(
            cek_entry
                .values
                .iter()
                .filter(|candidate| !std::ptr::eq(*candidate, primary)),
        );

        // Keep the first (primary) failure so the single-wrapping case reports
        // exactly the error it always did.
        let mut first_error: Option<EncryptionError> = None;
        for cek_value in candidates {
            let provider = match self
                .config
                .get_provider(&cek_value.key_store_provider_name)
            {
                Some(provider) => provider,
                None => {
                    if first_error.is_none() {
                        first_error = Some(EncryptionError::KeyStoreNotFound(
                            cek_value.key_store_provider_name.clone(),
                        ));
                    }
                    continue;
                }
            };

            match provider
                .decrypt_cek(
                    &cek_value.cmk_path,
                    &cek_value.encryption_algorithm,
                    &cek_value.encrypted_value,
                )
                .await
            {
                Ok(decrypted_cek) => {
                    // Create the encryptor, caching it when enabled.
                    return if self.cache_enabled {
                        self.cek_cache.insert(cache_key, decrypted_cek)
                    } else {
                        Ok(Arc::new(AeadEncryptor::new(&decrypted_cek)?))
                    };
                }
                Err(err) => {
                    if first_error.is_none() {
                        first_error = Some(err.into());
                    }
                }
            }
        }

        // `candidates` always yields `primary`, so an error was recorded; the
        // fallback only guards the impossible empty case.
        Err(first_error.unwrap_or_else(|| {
            EncryptionError::CekDecryptionFailed("No CEK value available".into())
        }))
    }

    /// Encrypt a value for a column.
    ///
    /// # Arguments
    ///
    /// * `plaintext` - The plaintext value to encrypt (already normalized)
    /// * `cek_entry` - The CEK table entry for this column
    /// * `encryption_type` - Deterministic or randomized encryption
    ///
    /// # Errors
    ///
    /// Errors from [`get_encryptor`](Self::get_encryptor), or
    /// `UnsupportedOperation` for an encryption type other than deterministic
    /// or randomized.
    pub(crate) async fn encrypt_value(
        &self,
        plaintext: &[u8],
        cek_entry: &CekTableEntry,
        encryption_type: EncryptionTypeWire,
    ) -> Result<Vec<u8>, EncryptionError> {
        let encryptor = self.get_encryptor(cek_entry).await?;

        let enc_type = match encryption_type {
            EncryptionTypeWire::Deterministic => mssql_auth::EncryptionType::Deterministic,
            EncryptionTypeWire::Randomized => mssql_auth::EncryptionType::Randomized,
            _ => {
                return Err(EncryptionError::UnsupportedOperation(format!(
                    "unsupported encryption type: {encryption_type:?}"
                )));
            }
        };

        encryptor.encrypt(plaintext, enc_type)
    }

    /// Check if a provider is registered under `name`.
    pub fn has_provider(&self, name: &str) -> bool {
        self.config.get_provider(name).is_some()
    }
}
310
/// Debug output summarizes the context: provider and cache-entry counts,
/// plus whether caching is on. Keys and providers are never printed.
#[cfg(feature = "always-encrypted")]
impl std::fmt::Debug for EncryptionContext {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let provider_count = self.config.providers.len();
        let cache_entries = self.cek_cache.len();
        let mut out = f.debug_struct("EncryptionContext");
        out.field("provider_count", &provider_count);
        out.field("cache_entries", &cache_entries);
        out.field("cache_enabled", &self.cache_enabled);
        out.finish()
    }
}
321
/// Parameter encryption metadata for one query.
///
/// Produced from the output of `sp_describe_parameter_encryption`. It pairs
/// the CEK table the parameters refer to with a per-parameter encryption
/// directive.
#[cfg(feature = "always-encrypted")]
#[derive(Debug, Clone)]
pub(crate) struct ParameterEncryptionInfo {
    /// Column encryption keys referenced by the parameters.
    pub cek_table: CekTable,
    /// Encryption directive, keyed by parameter name.
    pub parameters: HashMap<String, ParameterCryptoInfo>,
}

#[cfg(feature = "always-encrypted")]
impl ParameterEncryptionInfo {
    /// An empty value: an empty CEK table and no parameter directives.
    pub fn new() -> Self {
        let cek_table = CekTable::new();
        let parameters = HashMap::default();
        Self {
            cek_table,
            parameters,
        }
    }

    /// Look up the encryption directive recorded for parameter `name`.
    pub fn get_parameter(&self, name: &str) -> Option<&ParameterCryptoInfo> {
        self.parameters.get(name)
    }
}

/// Same as [`ParameterEncryptionInfo::new`].
#[cfg(feature = "always-encrypted")]
impl Default for ParameterEncryptionInfo {
    fn default() -> Self {
        Self::new()
    }
}
357
/// How one parameter must be encrypted, as listed in result set 2 of
/// `sp_describe_parameter_encryption`.
#[cfg(feature = "always-encrypted")]
#[derive(Debug, Clone)]
pub(crate) struct ParameterCryptoInfo {
    /// 0-based position of this parameter's key in
    /// [`ParameterEncryptionInfo::cek_table`].
    ///
    /// This is not the raw key ordinal the server reports, which is often
    /// 1-based. The describe parser maps that ordinal to a positional index,
    /// so `cek_table.get(cek_ordinal)` finds the entry directly.
    pub cek_ordinal: u16,
    /// Deterministic or randomized encryption.
    pub encryption_type: EncryptionTypeWire,
    /// Encryption algorithm identifier (2 = AEAD_AES_256_CBC_HMAC_SHA256).
    pub algorithm_id: u8,
    /// Version of the normalization rules applied to the plaintext before it
    /// is encrypted.
    pub normalization_rule_version: u8,
}
376
/// Parsing of the two result sets returned by `sp_describe_parameter_encryption`.
///
/// Result set 1 is the CEK table, with one row per CMK-wrapping of each CEK.
/// Result set 2 lists, per parameter, how the server expects it encrypted.
///
/// The column layout was captured against a live server (SQL Server 2016+).
/// The first nine RS1 columns are stable across versions. SQL Server 2019+
/// appends two enclave columns (`is_requested_by_enclave`,
/// `column_master_key_signature`), which this non-enclave path ignores.
/// Columns are read positionally to match the `Microsoft.Data.SqlClient`
/// ordinals.
#[cfg(feature = "always-encrypted")]
impl ParameterEncryptionInfo {
    /// Minimum RS1 column count (SQL Server 2017 returns exactly this; 2019+
    /// return more, with the extra columns appended after these).
    const RS1_MIN_COLS: usize = 9;
    /// RS2 column count, stable across supported versions.
    const RS2_MIN_COLS: usize = 6;

    /// Parse `sp_describe_parameter_encryption` output into encryption metadata.
    ///
    /// `result_sets` must be the `ProcedureResult::result_sets` from that RPC.
    /// Plaintext parameters (encryption type 0) are omitted from the result.
    ///
    /// # Errors
    ///
    /// `Error::Protocol` when:
    /// - fewer than two result sets are present;
    /// - a result set has too few columns;
    /// - a column has an unexpected type, or an id is negative;
    /// - there are more than `u16::MAX + 1` distinct CEKs;
    /// - a parameter references a key ordinal absent from the CEK table.
    pub(crate) fn from_describe_result_sets(result_sets: &mut [ResultSet]) -> Result<Self, Error> {
        if result_sets.len() < 2 {
            return Err(Error::Protocol(format!(
                "sp_describe_parameter_encryption returned {} result set(s), expected 2",
                result_sets.len()
            )));
        }

        // --- Result set 1: CEK table ---
        let rs1_cols = result_sets[0].columns().len();
        if rs1_cols < Self::RS1_MIN_COLS {
            return Err(Error::Protocol(format!(
                "sp_describe_parameter_encryption result set 1 has {rs1_cols} columns, expected >= {}",
                Self::RS1_MIN_COLS
            )));
        }
        let rs1_rows = result_sets[0].collect_all()?;

        let mut entries: Vec<CekTableEntry> = Vec::new();
        // Server-assigned key ordinal -> positional index into `entries`.
        let mut ordinal_to_index: HashMap<i32, u16> = HashMap::new();

        for row in &rs1_rows {
            let key_ordinal = describe_int(row, 0, "column_encryption_key_ordinal")?;
            let value = CekValue {
                encrypted_value: describe_varbinary(
                    row,
                    5,
                    "column_encryption_key_encrypted_value",
                )?,
                key_store_provider_name: describe_nvarchar(
                    row,
                    6,
                    "column_master_key_store_provider_name",
                )?,
                cmk_path: describe_nvarchar(row, 7, "column_master_key_path")?,
                encryption_algorithm: describe_nvarchar(
                    row,
                    8,
                    "column_encryption_key_encryption_algorithm_name",
                )?,
            };

            if let Some(&idx) = ordinal_to_index.get(&key_ordinal) {
                // Another CMK-wrapping of an already-seen CEK (key rotation).
                entries[usize::from(idx)].values.push(value);
            } else {
                let idx = u16::try_from(entries.len()).map_err(|_| {
                    Error::Protocol(
                        "sp_describe_parameter_encryption returned too many CEKs".into(),
                    )
                })?;
                ordinal_to_index.insert(key_ordinal, idx);
                entries.push(CekTableEntry {
                    database_id: Self::describe_u32(row, 1, "database_id")?,
                    cek_id: Self::describe_u32(row, 2, "column_encryption_key_id")?,
                    cek_version: Self::describe_u32(row, 3, "column_encryption_key_version")?,
                    cek_md_version: describe_md_version(row, 4)?,
                    values: vec![value],
                });
            }
        }
        let cek_table = CekTable { entries };

        // --- Result set 2: per-parameter directives ---
        let rs2_cols = result_sets[1].columns().len();
        if rs2_cols < Self::RS2_MIN_COLS {
            return Err(Error::Protocol(format!(
                "sp_describe_parameter_encryption result set 2 has {rs2_cols} columns, expected >= {}",
                Self::RS2_MIN_COLS
            )));
        }
        let rs2_rows = result_sets[1].collect_all()?;

        let mut parameters = HashMap::new();
        for row in &rs2_rows {
            let name = describe_nvarchar(row, 1, "parameter_name")?;
            let encryption_type_byte = describe_tinyint(row, 3, "column_encryption_type")?;
            // 0 = the server determined this parameter needs no encryption.
            if encryption_type_byte == 0 {
                continue;
            }
            let encryption_type =
                EncryptionTypeWire::from_u8(encryption_type_byte).ok_or_else(|| {
                    Error::Protocol(format!(
                        "sp_describe_parameter_encryption: invalid column_encryption_type {encryption_type_byte} for {name}"
                    ))
                })?;
            let algorithm_id = describe_tinyint(row, 2, "column_encryption_algorithm")?;
            let server_ordinal = describe_int(row, 4, "column_encryption_key_ordinal")?;
            let normalization_rule_version =
                describe_tinyint(row, 5, "column_encryption_normalization_rule_version")?;

            let cek_ordinal = *ordinal_to_index.get(&server_ordinal).ok_or_else(|| {
                Error::Protocol(format!(
                    "sp_describe_parameter_encryption: parameter {name} references CEK ordinal {server_ordinal} absent from the CEK table"
                ))
            })?;

            parameters.insert(
                name,
                ParameterCryptoInfo {
                    cek_ordinal,
                    encryption_type,
                    algorithm_id,
                    normalization_rule_version,
                },
            );
        }

        Ok(Self {
            cek_table,
            parameters,
        })
    }

    /// Read an `int` describe column that holds a non-negative identifier
    /// (database id, key id, key version) as a `u32`.
    ///
    /// A negative value is rejected as a protocol error. An `as u32` cast
    /// would instead wrap it into a large, wrong key identity.
    fn describe_u32(row: &Row, idx: usize, col: &str) -> Result<u32, Error> {
        let raw = describe_int(row, idx, col)?;
        u32::try_from(raw).map_err(|_| {
            Error::Protocol(format!(
                "sp_describe_parameter_encryption column {col} (#{idx}): expected a non-negative int, got {raw}"
            ))
        })
    }
}
514
/// Read column `idx` of a describe row as an `int`.
///
/// A value of any other type, or a missing value, becomes a protocol error
/// naming `col`.
#[cfg(feature = "always-encrypted")]
fn describe_int(row: &Row, idx: usize, col: &str) -> Result<i32, Error> {
    let cell = row.get_raw(idx);
    match cell {
        Some(SqlValue::Int(number)) => Ok(number),
        unexpected => Err(describe_type_error(col, idx, "int", unexpected.as_ref())),
    }
}
523
/// Read column `idx` of a describe row as a `tinyint`.
///
/// A value of any other type, or a missing value, becomes a protocol error
/// naming `col`.
#[cfg(feature = "always-encrypted")]
fn describe_tinyint(row: &Row, idx: usize, col: &str) -> Result<u8, Error> {
    let cell = row.get_raw(idx);
    match cell {
        Some(SqlValue::TinyInt(byte)) => Ok(byte),
        unexpected => Err(describe_type_error(col, idx, "tinyint", unexpected.as_ref())),
    }
}
532
/// Read column `idx` of a describe row as an `nvarchar` string.
///
/// A value of any other type, or a missing value, becomes a protocol error
/// naming `col`.
#[cfg(feature = "always-encrypted")]
fn describe_nvarchar(row: &Row, idx: usize, col: &str) -> Result<String, Error> {
    let cell = row.get_raw(idx);
    match cell {
        Some(SqlValue::String(text)) => Ok(text),
        unexpected => Err(describe_type_error(col, idx, "nvarchar", unexpected.as_ref())),
    }
}
541
/// Read column `idx` of a describe row as `varbinary` bytes.
///
/// A value of any other type, or a missing value, becomes a protocol error
/// naming `col`.
#[cfg(feature = "always-encrypted")]
fn describe_varbinary(row: &Row, idx: usize, col: &str) -> Result<bytes::Bytes, Error> {
    let cell = row.get_raw(idx);
    match cell {
        Some(SqlValue::Binary(payload)) => Ok(payload),
        unexpected => Err(describe_type_error(col, idx, "varbinary", unexpected.as_ref())),
    }
}
550
/// Decode the CEK metadata-version column, a `binary(8)`, as a
/// little-endian `u64`.
///
/// Any value that is not exactly 8 binary bytes (including a missing value)
/// is reported as a protocol error.
#[cfg(feature = "always-encrypted")]
fn describe_md_version(row: &Row, idx: usize) -> Result<u64, Error> {
    match row.get_raw(idx) {
        Some(SqlValue::Binary(raw)) if raw.len() == 8 => {
            // Walk from the most significant (last) byte down to the first,
            // which yields the little-endian interpretation.
            let version = raw
                .iter()
                .rev()
                .fold(0u64, |acc, &byte| (acc << 8) | u64::from(byte));
            Ok(version)
        }
        unexpected => Err(describe_type_error(
            "column_encryption_key_metadata_version",
            idx,
            "binary(8)",
            unexpected.as_ref(),
        )),
    }
}
568
/// Build the protocol error used when a describe column has the wrong type.
///
/// `got` is the value actually read. `None` is reported as `missing`;
/// otherwise the value's type name is shown.
#[cfg(feature = "always-encrypted")]
fn describe_type_error(col: &str, idx: usize, expected: &str, got: Option<&SqlValue>) -> Error {
    let got = match got {
        Some(value) => value.type_name(),
        None => "missing",
    };
    let message = format!(
        "sp_describe_parameter_encryption column {col} (#{idx}): expected {expected}, got {got}"
    );
    Error::Protocol(message)
}
577
/// Normalize a parameter value to the plaintext byte form Always Encrypted
/// encrypts: SQL Server's "normalized" form for the value's type. The result
/// is the plaintext input to `EncryptionContext::encrypt_value`.
///
/// Normalization is type-specific and is **not** the regular TDS wire
/// encoding. For example, INT normalizes to 8 little-endian bytes (not 4),
/// and strings/binaries carry no length prefix. These layouts are validated
/// byte-for-byte against Microsoft.Data.SqlClient (see the `ae_normalization`
/// tests). Only the types supported so far are handled; others return
/// `UnsupportedOperation`.
///
/// Typed temporal parameters (`time`/`datetime2`/`datetimeoffset`/`datetime`)
/// pass their [`mssql_types::EncryptedParamType`] in `param_type`. Their
/// normalized value is truncated to the column scale, so the value alone is
/// insufficient. A `Char` hint on a string selects the Windows-1252 `char`
/// encoding instead of UTF-16.
///
/// # Errors
///
/// - `EncryptionFailed` if a `char` value contains a character that is not
///   representable in Windows-1252.
/// - `UnsupportedOperation` for value types without a normalized form here.
/// - Errors from the temporal, `smalldatetime`, and money helpers are
///   propagated.
#[cfg(feature = "always-encrypted")]
pub fn normalize_for_encryption(
    value: &SqlValue,
    param_type: Option<mssql_types::EncryptedParamType>,
) -> Result<Vec<u8>, EncryptionError> {
    // CHAR: the value's bytes in the column code page (Windows-1252), unpadded.
    // NCHAR/BINARY reuse the String/Binary value arms below (UTF-16 / raw).
    if let (Some(mssql_types::EncryptedParamType::Char { .. }), SqlValue::String(s)) =
        (param_type, value)
    {
        let (encoded, _, had_errors) = encoding_rs::WINDOWS_1252.encode(s);
        // A code point absent from Windows-1252 is substituted by encoding_rs
        // with a numeric character reference (`&#NNNN;`). Those bytes are
        // garbage that no Microsoft client can read back (.NET stores `?`).
        // Erroring turns silent Always Encrypted corruption into a clear
        // failure instead of guessing at a lossy substitution. (`char` columns
        // are Windows-1252 only; use `nchar` for non-Latin text.)
        if had_errors {
            return Err(EncryptionError::EncryptionFailed(
                "char value contains characters not representable in Windows-1252 \
                 (the char column code page); use nchar for non-Latin text"
                    .to_string(),
            ));
        }
        return Ok(encoded.into_owned());
    }
    // Typed temporal parameters carry the column scale (the normalized value
    // is truncated to it), so they're handled from the hint, not the value.
    #[cfg(feature = "chrono")]
    {
        use mssql_types::EncryptedParamType as E;
        match (param_type, value) {
            (Some(E::Time { scale }), SqlValue::Time(t)) => return normalize_ae_time(*t, scale),
            (Some(E::DateTime2 { scale }), SqlValue::DateTime(dt)) => {
                return normalize_ae_datetime2(*dt, scale);
            }
            (Some(E::DateTimeOffset { scale }), SqlValue::DateTimeOffset(dto)) => {
                return normalize_ae_datetimeoffset(*dto, scale);
            }
            (Some(E::DateTime), SqlValue::DateTime(dt)) => {
                let mut buf = bytes::BytesMut::with_capacity(8);
                mssql_types::__private::encode_datetime_legacy(*dt, &mut buf);
                return Ok(buf.to_vec());
            }
            _ => {}
        }
    }
    match value {
        // All integer types AND bit normalize to 8-byte little-endian (the value
        // widened to i64). Validated against .NET: tinyint/smallint are 8 bytes,
        // not their native 1/2. A spec-reading would get this wrong.
        SqlValue::Bool(v) => Ok(i64::from(*v).to_le_bytes().to_vec()),
        SqlValue::TinyInt(v) => Ok(i64::from(*v).to_le_bytes().to_vec()),
        SqlValue::SmallInt(v) => Ok(i64::from(*v).to_le_bytes().to_vec()),
        SqlValue::Int(v) => Ok(i64::from(*v).to_le_bytes().to_vec()),
        SqlValue::BigInt(v) => Ok(v.to_le_bytes().to_vec()),
        // REAL/FLOAT: the IEEE-754 bits, little-endian (4 and 8 bytes).
        SqlValue::Float(v) => Ok(v.to_le_bytes().to_vec()),
        SqlValue::Double(v) => Ok(v.to_le_bytes().to_vec()),
        // NVARCHAR: UTF-16LE code units, no length prefix.
        SqlValue::String(s) => Ok(s.encode_utf16().flat_map(u16::to_le_bytes).collect()),
        // VARBINARY: the raw bytes, no length prefix.
        SqlValue::Binary(b) => Ok(b.to_vec()),
        // UNIQUEIDENTIFIER: SQL Server's 16-byte mixed-endian GUID order (first
        // three groups byte-reversed from the RFC layout, last 8 as-is).
        #[cfg(feature = "uuid")]
        SqlValue::Uuid(u) => {
            let b = u.as_bytes();
            Ok(vec![
                b[3], b[2], b[1], b[0], b[5], b[4], b[7], b[6], b[8], b[9], b[10], b[11], b[12],
                b[13], b[14], b[15],
            ])
        }
        // DATE: 3-byte little-endian count of days since 0001-01-01, shared
        // with the datetime2/datetimeoffset date part.
        #[cfg(feature = "chrono")]
        SqlValue::Date(d) => Ok(ae_date_bytes(*d).to_vec()),
        // DECIMAL/NUMERIC: 1 sign byte (0 negative, 1 positive) + 16-byte
        // little-endian unscaled magnitude. Uses the value's own scale.
        #[cfg(feature = "decimal")]
        SqlValue::Decimal(d) => {
            let mantissa = d.mantissa();
            let mut out = Vec::with_capacity(17);
            // Take the sign from the mantissa, not the sign flag. A negative
            // zero (`-0`) then normalizes exactly like `0`. Otherwise it would
            // produce a different plaintext, and under deterministic
            // encryption a different ciphertext, for an equal value.
            out.push(u8::from(mantissa >= 0));
            out.extend_from_slice(&mantissa.unsigned_abs().to_le_bytes());
            Ok(out)
        }
        // MONEY and SMALLMONEY both normalize to the 8-byte MONEY form: the
        // value scaled by 10_000 as an i64, high 32 bits then low 32 bits.
        #[cfg(feature = "decimal")]
        SqlValue::Money(d) | SqlValue::SmallMoney(d) => {
            let cents = money_cents(d)?;
            let mut out = ((cents >> 32) as i32).to_le_bytes().to_vec();
            out.extend_from_slice(&(cents as u32).to_le_bytes());
            Ok(out)
        }
        // SMALLDATETIME: 2-byte days-since-1900 + 2-byte minutes-since-midnight.
        // Declared correctly by the `SmallDateTime` wrapper, so no scale hint.
        #[cfg(feature = "chrono")]
        SqlValue::SmallDateTime(dt) => {
            let mut buf = bytes::BytesMut::with_capacity(4);
            mssql_types::__private::encode_smalldatetime(*dt, &mut buf).map_err(|e| {
                EncryptionError::UnsupportedOperation(format!("SMALLDATETIME: {e}"))
            })?;
            Ok(buf.to_vec())
        }
        other => Err(EncryptionError::UnsupportedOperation(format!(
            "Always Encrypted parameter encryption is not yet implemented for {}",
            other.type_name()
        ))),
    }
}
706
/// The date part of the AE normalized form for `date`, `datetime2`, and
/// `datetimeoffset`: the number of days since 0001-01-01, taken as the low
/// 3 bytes of a little-endian integer.
///
/// No range check is done here. A date before 0001-01-01 makes the day count
/// negative, and the `as u32` cast wraps it.
#[cfg(all(feature = "always-encrypted", feature = "chrono"))]
fn ae_date_bytes(d: chrono::NaiveDate) -> [u8; 3] {
    use chrono::Datelike;
    // `num_days_from_ce` is 1 for 0001-01-01, so shift to a 0-based count.
    let day_number = (d.num_days_from_ce() - 1) as u32;
    let [low, middle, high, _] = day_number.to_le_bytes();
    [low, middle, high]
}
716
/// The AE normalized form for `time(scale)`.
///
/// The time since midnight is counted in 100 ns (scale-7) ticks and truncated
/// down to a multiple of `10^(7 - scale)` ticks, i.e. to the column's
/// `scale`. It is always written as 5 little-endian bytes.
///
/// This fixed width is *not* SQL Server's scale-dependent 3/4/5-byte `time`
/// storage form, and sub-scale digits are truncated, not rounded.
///
/// # Errors
///
/// `UnsupportedOperation` if `scale > 7`. It is also returned if the value
/// reaches 24:00:00, which `time` cannot represent. That happens only for a
/// chrono leap second (`nanosecond >= 1_000_000_000`) at 23:59:59.
#[cfg(all(feature = "always-encrypted", feature = "chrono"))]
fn normalize_ae_time(t: chrono::NaiveTime, scale: u8) -> Result<Vec<u8>, EncryptionError> {
    use chrono::Timelike;
    /// Number of 100 ns ticks in 24 hours.
    const TICKS_PER_DAY: u64 = 24 * 60 * 60 * 10_000_000;
    if scale > 7 {
        return Err(EncryptionError::UnsupportedOperation(format!(
            "time scale {scale} out of range (0–7)"
        )));
    }
    let nanos =
        u64::from(t.num_seconds_from_midnight()) * 1_000_000_000 + u64::from(t.nanosecond());
    // Always Encrypted normalizes temporal values to a FIXED scale-7 (100ns)
    // width: 5 bytes for time. The value is quantized (truncated) to the
    // column scale. This is NOT the scale-dependent 3/4/5-byte TDS storage
    // form. It matches Microsoft.Data.SqlClient's `SerializeTime`, which
    // forces `length = MAX_TIME_LENGTH = 5` and quantizes via
    // `ticks / TICKS_FROM_SCALE[scale] * TICKS_FROM_SCALE[scale]`. The two
    // forms coincide only at scale 7, so the old scale-dependent length
    // emitted ciphertext no Microsoft client could read at scale 0–6.
    // Validated byte-exact vs .NET at scale 3 and scale 7.
    let ticks7 = nanos / 100;
    let quantum = 10u64.pow(7 - u32::from(scale));
    let quantized = (ticks7 / quantum) * quantum;
    // The server never sees this plaintext and so cannot reject it.
    // Refuse values past midnight rather than encrypt bytes no client can
    // decode back into a `time`.
    if quantized >= TICKS_PER_DAY {
        return Err(EncryptionError::UnsupportedOperation(format!(
            "time value {t} is outside the time range 00:00:00–23:59:59.9999999"
        )));
    }
    Ok(quantized.to_le_bytes()[..5].to_vec())
}
744
/// AE normalized `datetime2(scale)`: the `time(scale)` tick bytes for the
/// time-of-day, immediately followed by the 3-byte date.
///
/// Propagates any error from the time normalization (e.g. an invalid scale).
#[cfg(all(feature = "always-encrypted", feature = "chrono"))]
fn normalize_ae_datetime2(
    dt: chrono::NaiveDateTime,
    scale: u8,
) -> Result<Vec<u8>, EncryptionError> {
    let time_part = normalize_ae_time(dt.time(), scale)?;
    let date_part = ae_date_bytes(dt.date());
    Ok([time_part.as_slice(), date_part.as_slice()].concat())
}
756
/// AE normalized `datetimeoffset(scale)`: the value converted to UTC and laid
/// out exactly like `datetime2(scale)` (UTC time ticks, then the 3-byte UTC
/// date), followed by the original offset in whole minutes as a 2-byte
/// little-endian `i16`.
#[cfg(all(feature = "always-encrypted", feature = "chrono"))]
fn normalize_ae_datetimeoffset(
    dto: chrono::DateTime<chrono::FixedOffset>,
    scale: u8,
) -> Result<Vec<u8>, EncryptionError> {
    use chrono::Offset;
    // The time+date prefix is identical to datetime2 applied to the UTC instant.
    let mut normalized = normalize_ae_datetime2(dto.naive_utc(), scale)?;
    let offset_seconds = dto.offset().fix().local_minus_utc();
    normalized.extend_from_slice(&((offset_seconds / 60) as i16).to_le_bytes());
    Ok(normalized)
}
772
/// Converts `value` to the MONEY fixed-point representation (`value * 10_000`)
/// as an `i64`. Fractional digits beyond the fourth are dropped (rounding
/// toward zero). Shared by MONEY and SMALLMONEY normalization.
///
/// Fails with `UnsupportedOperation("MONEY value out of range")` when the
/// scaled value does not fit in an `i64`.
#[cfg(all(feature = "always-encrypted", feature = "decimal"))]
fn money_cents(value: &rust_decimal::Decimal) -> Result<i64, EncryptionError> {
    // MONEY always carries exactly four fractional digits.
    const MONEY_SCALE: u32 = 4;
    let out_of_range =
        || EncryptionError::UnsupportedOperation("MONEY value out of range".into());

    let mantissa = value.mantissa();
    let scale = value.scale();
    let scaled: i128 = match scale.checked_sub(MONEY_SCALE) {
        // Too many fractional digits: integer `/` truncates toward zero.
        Some(excess) => mantissa / 10_i128.pow(excess),
        // Too few: pad with zeros, which can overflow for huge mantissas.
        None => mantissa
            .checked_mul(10_i128.pow(MONEY_SCALE - scale))
            .ok_or_else(out_of_range)?,
    };
    i64::try_from(scaled).map_err(|_| out_of_range())
}
791
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Decodes an even-length hex string (either case) into bytes.
    ///
    /// Panics on malformed input, which is acceptable for hard-coded vectors.
    #[cfg(feature = "always-encrypted")]
    fn unhex(s: &str) -> Vec<u8> {
        (0..s.len())
            .step_by(2)
            .map(|i| u8::from_str_radix(&s[i..i + 2], 16).unwrap())
            .collect()
    }

    /// Encrypts the AE-normalized form of each value deterministically under
    /// the CEK `cek_hex`. Asserts that each ciphertext equals its paired
    /// reference hex, captured from Microsoft.Data.SqlClient.
    ///
    /// Deterministic AEAD is a pure function of key and plaintext, so a
    /// byte-exact ciphertext match proves the normalized layout matches .NET.
    #[cfg(feature = "always-encrypted")]
    #[track_caller]
    fn assert_ciphertexts_match_dotnet(
        cek_hex: &str,
        cases: impl IntoIterator<Item = (SqlValue, &'static str)>,
    ) {
        let enc = AeadEncryptor::new(&unhex(cek_hex)).unwrap();
        for (value, reference) in cases {
            let norm = normalize_for_encryption(&value, None).unwrap();
            let cipher = enc
                .encrypt(&norm, mssql_auth::EncryptionType::Deterministic)
                .unwrap();
            assert_eq!(
                cipher,
                unhex(reference),
                "ciphertext for {} must match Microsoft.Data.SqlClient",
                value.type_name()
            );
        }
    }

    /// Reference ciphertexts captured from a live deterministic Always Encrypted
    /// INSERT via Microsoft.Data.SqlClient 5.2.2. Encrypting our normalization
    /// with the same CEK must reproduce them byte-for-byte. This proves the
    /// normalized layout matches the real .NET client (notably INT -> 8 LE bytes,
    /// which is the layout a naive implementation would get wrong).
    #[cfg(feature = "always-encrypted")]
    #[test]
    fn ae_normalization_matches_dotnet() {
        use bytes::Bytes;

        assert_ciphertexts_match_dotnet(
            "B59D9F2C96784C232D53AB273D257DC79B7D2355BB82B1EC7054CE25E25F7B44",
            [
                (
                    SqlValue::Int(42),
                    "01102FC5DEC5D3E463A8F4BDF512AA74E6AB953BA9A2F3F9A98CD18446B007DE5A6E2A1D1EB775035EA189CA5160A935CE093CAA9BB7E9233BB333AADEE86FDE1D",
                ),
                (
                    SqlValue::String("Ada".to_string()),
                    "01BFAC40E6DA541ACEFAD8ECF5598DB77B0C5349CFACBC3C9221C01B6037E593B78E8F398F620F837BD6A4A2B644125C4188DF278B94479B2218466D91107FE417",
                ),
                (
                    SqlValue::Binary(Bytes::from_static(&[0x01, 0x02, 0x03])),
                    "01ADE71457495F00FC9A16456F1B1EECB901D88DE97887025C189B1C4432E02071AB7594C48518CA5621E90165FAE337475B4CF3A3D00EF2D862FB0473713DF1E1",
                ),
            ],
        );
    }

    /// `normalize_for_encryption` rejects values it has no normalization for,
    /// rather than silently producing wrong bytes. NULL is never normalized
    /// (it is handled as a NULL parameter upstream), so it exercises the
    /// catch-all rejection arm and stays unsupported as more types are added.
    #[cfg(feature = "always-encrypted")]
    #[test]
    fn ae_normalization_rejects_unnormalizable_value() {
        assert!(normalize_for_encryption(&SqlValue::Null, None).is_err());
    }

    /// Numeric-scalar normalization, validated byte-for-byte against
    /// Microsoft.Data.SqlClient (same method as [`ae_normalization_matches_dotnet`],
    /// captured with a fresh CEK). This is the interop guarantee: a value the
    /// driver encrypts is the value .NET would encrypt. Notable: every integer
    /// width and bit normalize to 8 bytes, real to 4, float to 8.
    #[cfg(feature = "always-encrypted")]
    #[test]
    fn ae_normalization_matches_dotnet_numeric() {
        assert_ciphertexts_match_dotnet(
            "9590E42A8A6C8F13B5D09B8D5A128EF8B3A4A10301C7AF24AFC62ED0E02342F7",
            [
                (
                    SqlValue::BigInt(0x0102030405060708),
                    "01E765FC4696660028BFD48FCAEAED81E0EB423CFF433CA97F1B2FF02F70744E7265C2AE73CAA562FFA98AF98CB1D3EF6A4649B3640359E1DB7D170C80E639DA68",
                ),
                (
                    SqlValue::SmallInt(258),
                    "012545AB817E1AEBDCEE1C00AEBFF3A013CAD20E0377BEFDD9186C263F8D1A909C313A753996F1B5E4A4AE17E901F6F781DCA707544766995D339601CA414063A0",
                ),
                (
                    SqlValue::TinyInt(200),
                    "01A97C33480277D16FFAEDA9068173D4173378542F2887EBCD31CDEEEB116BD59D48F9D459BDDCABAE469E891B4F82AA3D283440CA1B5E9FFC150F9D0AE54EC21E",
                ),
                (
                    SqlValue::Bool(true),
                    "01DDE18564051D630EE026331BCCAFC8F4122CC3919F81459F37D9C0E0C64A5317FCA08660FE5FC855917B97B72013F25B85ADD14ADDD7D5ED022EB1297FF29A7E",
                ),
                (
                    SqlValue::Float(3.5),
                    "017A452760E7BA7AA6A716F6707F55D9C3A81683C04A6B561B13AC1D8A848E93E239BB922EE3EE628B6D0081A590BB11747CC25D216240FB10171A0FA3B99A2DB3",
                ),
                (
                    SqlValue::Double(3.5),
                    "0171611557351FBC4561EBF0B9C98E0DC38AD2BD3E2C1D1E82F185D7E67D0425E506D11DD67BA3EB38F34FB01A8FCEF7E4B9A7256944334A521526613CFF6C8C5F",
                ),
            ],
        );
    }

    /// UUID and DATE normalization, validated byte-for-byte against
    /// Microsoft.Data.SqlClient. uuid uses SQL Server's mixed-endian GUID byte
    /// order; date is a 3-byte little-endian day count since 0001-01-01.
    #[cfg(all(feature = "always-encrypted", feature = "uuid", feature = "chrono"))]
    #[test]
    fn ae_normalization_matches_dotnet_uuid_date() {
        assert_ciphertexts_match_dotnet(
            "9590E42A8A6C8F13B5D09B8D5A128EF8B3A4A10301C7AF24AFC62ED0E02342F7",
            [
                (
                    SqlValue::Uuid(
                        uuid::Uuid::parse_str("01020304-0506-0708-090a-0b0c0d0e0f10").unwrap(),
                    ),
                    "01F58635AA18692D68BDF551ECDD7AC3A56682D3F91F111F8D8F36D5425C405A8F6AB3ED3C3666444478476BD65FF40DC83F6831F502826AFEEC3116F71A7A2020CCD254F4BA28FCDC0F96BA2E5264AE9E",
                ),
                (
                    SqlValue::Date(chrono::NaiveDate::from_ymd_opt(2024, 3, 15).unwrap()),
                    "0188B4F75A1F4BDA53C9CDDC1918C09CB57F68E13F5560F1F1D7168FE70707337B1156A97915B244F3C03D3E7352882A599511BD243471FD03683F371CF44E4B76",
                ),
            ],
        );
    }

    /// DECIMAL and MONEY/SMALLMONEY normalization, validated byte-for-byte
    /// against Microsoft.Data.SqlClient. decimal is a sign byte plus a 16-byte
    /// little-endian unscaled magnitude. money and smallmoney both use the
    /// 8-byte MONEY form (value × 10_000, high then low 32 bits).
    #[cfg(all(feature = "always-encrypted", feature = "decimal"))]
    #[test]
    fn ae_normalization_matches_dotnet_decimal_money() {
        let dec = rust_decimal::Decimal::new(123_456_789, 4); // 12345.6789
        let money = rust_decimal::Decimal::new(123_400, 4); // 12.3400

        assert_ciphertexts_match_dotnet(
            "CBFB5AE21FB517C65DA0C6E8E11969C630798E473EF5827A70398012DF1D4B9E",
            [
                (
                    SqlValue::Decimal(dec),
                    "018FAE46024B9B406C23600E6A9C694F9A9B39B785A995689EBE19437BA7E75768011A035A5B54B5E495512EBB46AE1146130940A0D0D834D61AA89B5AD9F71FFAF6EEEAE77E4856BA2AA5E016E2950A8D",
                ),
                (
                    SqlValue::Money(money),
                    "01B4CE4CAD8D6B241A1555C377A0ADD4C79424DD5162F710D116594F725C1BAB015169A0C7716076EEC90E013519B961DEF427BFC32462D9E45D166C791B73F793",
                ),
                (
                    SqlValue::SmallMoney(money),
                    "01B4CE4CAD8D6B241A1555C377A0ADD4C79424DD5162F710D116594F725C1BAB015169A0C7716076EEC90E013519B961DEF427BFC32462D9E45D166C791B73F793",
                ),
            ],
        );
    }

    /// Temporal AE normalization, validated byte-for-byte against the forms
    /// `Microsoft.Data.SqlClient` produces (decrypted from its ciphertext).
    /// Comparing the normalized plaintext is equivalent to comparing ciphertext,
    /// because deterministic AEAD is a pure function of key and plaintext.
    /// Covers several cases:
    /// - scale 7, where the AE and TDS widths coincide;
    /// - scale 3, where they differ and the value is truncated;
    /// - legacy `datetime` and `smalldatetime`.
    #[cfg(all(feature = "always-encrypted", feature = "chrono"))]
    #[test]
    fn ae_normalization_matches_dotnet_temporal() {
        use mssql_types::EncryptedParamType as E;

        let day = chrono::NaiveDate::from_ymd_opt(2024, 3, 15).unwrap();
        let dt = day.and_hms_nano_opt(13, 14, 15, 123_456_700).unwrap();

        // time(7)
        assert_eq!(
            normalize_for_encryption(&SqlValue::Time(dt.time()), Some(E::Time { scale: 7 }))
                .unwrap(),
            unhex("07c4aaf46e"),
        );
        // datetime2(7)
        assert_eq!(
            normalize_for_encryption(&SqlValue::DateTime(dt), Some(E::DateTime2 { scale: 7 }))
                .unwrap(),
            unhex("07c4aaf46e8f460b"),
        );
        // datetimeoffset(7) +05:30 — normalized as UTC time + UTC date + offset minutes
        let dto = {
            use chrono::TimeZone;
            chrono::FixedOffset::east_opt(5 * 3600 + 30 * 60)
                .unwrap()
                .from_local_datetime(&dt)
                .single()
                .unwrap()
        };
        assert_eq!(
            normalize_for_encryption(
                &SqlValue::DateTimeOffset(dto),
                Some(E::DateTimeOffset { scale: 7 })
            )
            .unwrap(),
            unhex("0788f2da408f460b4a01"),
        );

        // Scale 3: AE keeps the FIXED 5/8/10-byte width and truncates the value
        // to the column scale (.1234567 → .1230000). This is the case the old
        // scale-dependent length got wrong. Captured from Microsoft.Data.SqlClient.
        assert_eq!(
            normalize_for_encryption(&SqlValue::Time(dt.time()), Some(E::Time { scale: 3 }))
                .unwrap(),
            unhex("30b2aaf46e"),
        );
        assert_eq!(
            normalize_for_encryption(&SqlValue::DateTime(dt), Some(E::DateTime2 { scale: 3 }))
                .unwrap(),
            unhex("30b2aaf46e8f460b"),
        );
        assert_eq!(
            normalize_for_encryption(
                &SqlValue::DateTimeOffset(dto),
                Some(E::DateTimeOffset { scale: 3 })
            )
            .unwrap(),
            unhex("3076f2da408f460b4a01"),
        );

        // legacy datetime
        let dt_legacy = day.and_hms_milli_opt(13, 14, 15, 123).unwrap();
        assert_eq!(
            normalize_for_encryption(&SqlValue::DateTime(dt_legacy), Some(E::DateTime)).unwrap(),
            unhex("34b10000d925da00"),
        );
        // smalldatetime (no scale hint — declared by the SmallDateTime wrapper)
        let sdt = day.and_hms_opt(13, 14, 0).unwrap();
        assert_eq!(
            normalize_for_encryption(&SqlValue::SmallDateTime(sdt), None).unwrap(),
            unhex("34b11a03"),
        );
    }

    /// Fixed-width char/nchar/binary AE normalization, validated byte-for-byte
    /// against Microsoft.Data.SqlClient. KEY FACT: the normalized form is the
    /// value's bytes, NOT padded to the declared width. char uses the column
    /// code page (Windows-1252), nchar uses UTF-16LE, and binary is raw.
    #[cfg(feature = "always-encrypted")]
    #[test]
    fn ae_normalization_matches_dotnet_fixed_width() {
        use mssql_types::EncryptedParamType as E;

        // char(10) "Hello" → Windows-1252 "Hello" (5 bytes, unpadded)
        assert_eq!(
            normalize_for_encryption(
                &SqlValue::String("Hello".to_string()),
                Some(E::Char { length: 10 })
            )
            .unwrap(),
            unhex("48656c6c6f"),
        );
        // nchar(10) "Hello" → UTF-16LE (10 bytes, unpadded)
        assert_eq!(
            normalize_for_encryption(
                &SqlValue::String("Hello".to_string()),
                Some(E::NChar { length: 10 })
            )
            .unwrap(),
            unhex("480065006c006c006f00"),
        );
        // binary(10) [1,2,3,4,5] → raw (5 bytes, unpadded)
        assert_eq!(
            normalize_for_encryption(
                &SqlValue::Binary(bytes::Bytes::from_static(&[1, 2, 3, 4, 5])),
                Some(E::Binary { length: 10 })
            )
            .unwrap(),
            unhex("0102030405"),
        );
    }

    /// A `char` value carrying a code point absent from Windows-1252 must error,
    /// not silently encode to numeric-character-reference garbage. For
    /// unmappable code points, `encoding_rs` substitutes `&#NNNN;` and sets
    /// `had_errors`. Discarding that flag produced ciphertext no Microsoft client
    /// could read (.NET stores `?` for the same input). Erroring turns silent
    /// corruption into a clear failure.
    #[cfg(feature = "always-encrypted")]
    #[test]
    fn ae_char_rejects_non_windows_1252() {
        use mssql_types::EncryptedParamType as E;
        let r = normalize_for_encryption(
            &SqlValue::String("中".to_string()),
            Some(E::Char { length: 10 }),
        );
        assert!(
            r.is_err(),
            "non-Windows-1252 char must error, got {:?}",
            r.map(|b| b.iter().map(|x| format!("{x:02x}")).collect::<String>())
        );
        // A value fully representable in Windows-1252 still normalizes (é = 0xE9).
        assert_eq!(
            normalize_for_encryption(
                &SqlValue::String("é".to_string()),
                Some(E::Char { length: 10 })
            )
            .unwrap(),
            vec![0xE9],
        );
    }

    #[test]
    fn test_encryption_config_defaults() {
        let config = EncryptionConfig::new();
        assert!(config.enabled);
        assert!(config.cache_ceks);
        assert!(!config.is_ready()); // No providers
    }

    /// Parses synthetic `sp_describe_parameter_encryption` result sets that
    /// mirror the live wire shape (captured in `.tmp/ae-3a2-describe-schema.md`).
    /// The live test validates the column *order* separately. This test covers
    /// logic the live single-CEK/single-CMK case cannot:
    /// - grouping multiple CMK-wrappings under one CEK;
    /// - translating the server's (1-based) key ordinal to a positional index;
    /// - decoding the little-endian `binary(8)` md-version;
    /// - skipping plaintext parameters.
    #[cfg(feature = "always-encrypted")]
    #[test]
    fn parse_describe_result_sets_groups_ceks_and_skips_plaintext() {
        use crate::row::{Column, Row};
        use crate::stream::ResultSet;
        use bytes::Bytes;

        fn rs(n_cols: usize, rows: Vec<Vec<SqlValue>>) -> ResultSet {
            let cols: Vec<Column> = (0..n_cols)
                .map(|i| Column::new(format!("c{i}"), i, "x"))
                .collect();
            let rows = rows
                .into_iter()
                .map(|vals| Row::from_values(cols.clone(), vals))
                .collect();
            ResultSet::new(cols, rows)
        }

        let mdv1 = Bytes::from_static(&[1, 0, 0, 0, 0, 0, 0, 0]); // -> 1
        let mdv2 = Bytes::from_static(&[255, 0, 0, 0, 0, 0, 0, 0]); // -> 255

        // RS1: CEK ordinal 1 wrapped by two CMKs (rotation), plus CEK ordinal 2.
        let rs1 = rs(
            9,
            vec![
                vec![
                    SqlValue::Int(1),
                    SqlValue::Int(7),
                    SqlValue::Int(56),
                    SqlValue::Int(1),
                    SqlValue::Binary(mdv1.clone()),
                    SqlValue::Binary(Bytes::from_static(b"env-a")),
                    SqlValue::String("IN_MEMORY_KEY_STORE".into()),
                    SqlValue::String("path-a".into()),
                    SqlValue::String("RSA_OAEP".into()),
                ],
                vec![
                    SqlValue::Int(1),
                    SqlValue::Int(7),
                    SqlValue::Int(56),
                    SqlValue::Int(1),
                    SqlValue::Binary(mdv1),
                    SqlValue::Binary(Bytes::from_static(b"env-a2")),
                    SqlValue::String("PROV_2".into()),
                    SqlValue::String("path-a2".into()),
                    SqlValue::String("RSA_OAEP".into()),
                ],
                vec![
                    SqlValue::Int(2),
                    SqlValue::Int(7),
                    SqlValue::Int(57),
                    SqlValue::Int(1),
                    SqlValue::Binary(mdv2),
                    SqlValue::Binary(Bytes::from_static(b"env-b")),
                    SqlValue::String("IN_MEMORY_KEY_STORE".into()),
                    SqlValue::String("path-b".into()),
                    SqlValue::String("RSA_OAEP".into()),
                ],
            ],
        );

        // RS2: @det on CEK ordinal 1, @rand on CEK ordinal 2, @plain plaintext.
        let rs2 = rs(
            6,
            vec![
                vec![
                    SqlValue::Int(1),
                    SqlValue::String("@det".into()),
                    SqlValue::TinyInt(2),
                    SqlValue::TinyInt(1),
                    SqlValue::Int(1),
                    SqlValue::TinyInt(1),
                ],
                vec![
                    SqlValue::Int(2),
                    SqlValue::String("@rand".into()),
                    SqlValue::TinyInt(2),
                    SqlValue::TinyInt(2),
                    SqlValue::Int(2),
                    SqlValue::TinyInt(1),
                ],
                vec![
                    SqlValue::Int(3),
                    SqlValue::String("@plain".into()),
                    SqlValue::TinyInt(0),
                    SqlValue::TinyInt(0),
                    SqlValue::Int(0),
                    SqlValue::TinyInt(0),
                ],
            ],
        );

        let mut sets = vec![rs1, rs2];
        let info = ParameterEncryptionInfo::from_describe_result_sets(&mut sets).unwrap();

        assert_eq!(info.cek_table.len(), 2);
        let e0 = info.cek_table.get(0).unwrap();
        assert_eq!(e0.cek_id, 56);
        assert_eq!(e0.cek_md_version, 1);
        assert_eq!(e0.values.len(), 2, "two CMK-wrappings group under one CEK");
        assert_eq!(e0.values[0].key_store_provider_name, "IN_MEMORY_KEY_STORE");
        assert_eq!(e0.values[1].key_store_provider_name, "PROV_2");
        let e1 = info.cek_table.get(1).unwrap();
        assert_eq!(e1.cek_id, 57);
        assert_eq!(e1.cek_md_version, 255);

        let det = info.get_parameter("@det").unwrap();
        assert_eq!(det.encryption_type, EncryptionTypeWire::Deterministic);
        assert_eq!(det.algorithm_id, 2);
        assert_eq!(det.normalization_rule_version, 1);
        assert_eq!(det.cek_ordinal, 0, "server ordinal 1 -> positional index 0");

        let rand = info.get_parameter("@rand").unwrap();
        assert_eq!(rand.encryption_type, EncryptionTypeWire::Randomized);
        assert_eq!(
            rand.cek_ordinal, 1,
            "server ordinal 2 -> positional index 1"
        );

        assert!(info.get_parameter("@plain").is_none());
        assert_eq!(info.parameters.len(), 2);
    }

    /// A truncated response (fewer than two result sets) must be rejected, not
    /// silently treated as "no parameters need encryption".
    #[cfg(feature = "always-encrypted")]
    #[test]
    fn parse_describe_result_sets_rejects_missing_result_set() {
        use crate::row::{Column, Row};
        use crate::stream::ResultSet;

        let cols: Vec<Column> = (0..9)
            .map(|i| Column::new(format!("c{i}"), i, "x"))
            .collect();
        let mut sets = vec![ResultSet::new(cols, Vec::<Row>::new())];
        assert!(ParameterEncryptionInfo::from_describe_result_sets(&mut sets).is_err());
    }
}