Skip to main content

tds_protocol/
crypto.rs

1//! Always Encrypted cryptography metadata for TDS protocol.
2//!
3//! This module defines the wire-level structures for SQL Server's Always Encrypted
4//! feature. When a query returns encrypted columns, SQL Server sends additional
5//! metadata describing how to decrypt the data.
6//!
7//! ## TDS Wire Format
8//!
9//! When Always Encrypted is enabled, the COLMETADATA token includes:
10//!
11//! 1. **CEK Table**: A table of Column Encryption Keys needed for the result set
12//! 2. **CryptoMetadata**: Per-column encryption information
13//!
14//! ```text
15//! COLMETADATA Token (with encryption):
16//! ┌─────────────────────────────────────────────────────────────────┐
17//! │ Column Count (2 bytes)                                          │
18//! ├─────────────────────────────────────────────────────────────────┤
19//! │ CEK Table (if encrypted columns present)                        │
20//! │ ├── CEK Count (2 bytes)                                         │
21//! │ ├── CEK Entry 1                                                 │
22//! │ │   ├── Database ID (4 bytes)                                   │
23//! │ │   ├── CEK ID (4 bytes)                                        │
24//! │ │   ├── CEK Version (4 bytes)                                   │
25//! │ │   ├── CEK MD Version (8 bytes)                                │
26//! │ │   ├── CEK Value Count (1 byte)                                │
27//! │ │   └── CEK Value(s)                                            │
28//! │ │       ├── Encrypted Value Length (2 bytes)                    │
29//! │ │       ├── Encrypted Value (variable)                          │
30//! │ │       ├── Key Store Name (B_VARCHAR)                          │
31//! │ │       ├── CMK Path (US_VARCHAR)                               │
32//! │ │       └── Algorithm (B_VARCHAR)                               │
33//! │ └── ...more CEK entries                                         │
34//! ├─────────────────────────────────────────────────────────────────┤
35//! │ Column Definitions                                              │
36//! │ ├── Column 1                                                    │
37//! │ │   ├── User Type (4 bytes)                                     │
38//! │ │   ├── Flags (2 bytes) - includes encryption flag              │
39//! │ │   ├── Type ID (1 byte)                                        │
40//! │ │   ├── Type Info (variable)                                    │
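//! │ │   ├── CryptoMetadata (if encrypted)                           │
//! │ │   │   ├── CEK Table Ordinal (2 bytes)                         │
//! │ │   │   ├── Base User Type (4 bytes)                            │
//! │ │   │   ├── Base Type ID (1 byte)                               │
//! │ │   │   ├── Base Type Info (variable)                           │
//! │ │   │   ├── Algorithm ID (1 byte)                               │
//! │ │   │   ├── Algorithm Name (B_VARCHAR, only if ID = 0)          │
//! │ │   │   ├── Encryption Type (1 byte)                            │
//! │ │   │   └── Normalization Version (1 byte)                      │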
46//! │ │   └── Column Name (B_VARCHAR)                                 │
47//! │ └── ...more columns                                             │
48//! └─────────────────────────────────────────────────────────────────┘
49//! ```
50
51use bytes::{Buf, Bytes};
52
53use crate::codec::{read_b_varchar, read_us_varchar};
54use crate::error::ProtocolError;
55use crate::prelude::*;
56
/// Bit 11 of the COLMETADATA column `Flags` field; set when the column is encrypted.
pub const COLUMN_FLAG_ENCRYPTED: u16 = 1 << 11;

/// Wire ID of the `AEAD_AES_256_CBC_HMAC_SHA256` cell-encryption algorithm.
pub const ALGORITHM_AEAD_AES_256_CBC_HMAC_SHA256: u8 = 0x02;

/// Wire value of the `encryption_type` byte for deterministic encryption.
pub const ENCRYPTION_TYPE_DETERMINISTIC: u8 = 0x01;

/// Wire value of the `encryption_type` byte for randomized encryption.
pub const ENCRYPTION_TYPE_RANDOMIZED: u8 = 0x02;

/// Normalization rule version understood by this implementation.
pub const NORMALIZATION_RULE_VERSION: u8 = 0x01;
71
72/// Column Encryption Key table entry.
73///
74/// This represents a single CEK entry in the CEK table sent with COLMETADATA.
75/// Multiple columns may share the same CEK.
76#[derive(Debug, Clone)]
77pub struct CekTableEntry {
78    /// Database ID where the CEK is defined.
79    pub database_id: u32,
80    /// CEK ID within the database.
81    pub cek_id: u32,
82    /// CEK version (incremented on key rotation).
83    pub cek_version: u32,
84    /// Metadata version (changes with any metadata update).
85    pub cek_md_version: u64,
86    /// CEK value entries (usually one, but may have multiple for key rotation).
87    pub values: Vec<CekValue>,
88}
89
90/// A single CEK value (encrypted by CMK).
91///
92/// A CEK may have multiple values when key rotation is in progress,
93/// with different CMKs encrypting the same CEK.
94#[derive(Debug, Clone)]
95pub struct CekValue {
96    /// The encrypted CEK bytes.
97    pub encrypted_value: Bytes,
98    /// Name of the key store provider (e.g., "AZURE_KEY_VAULT").
99    pub key_store_provider_name: String,
100    /// Path to the Column Master Key in the key store.
101    pub cmk_path: String,
102    /// Asymmetric algorithm used to encrypt the CEK (e.g., "RSA_OAEP").
103    pub encryption_algorithm: String,
104}
105
106/// Per-column encryption metadata.
107///
108/// This metadata is present for each encrypted column and describes
109/// how to decrypt the column data. Per MS-TDS 2.2.7.4, the wire format is:
110///
111/// ```text
112/// ordinal(2) + base_user_type(4) + base_col_type(1) + base_type_info(var) + algo_id(1) + enc_type(1) + norm_ver(1)
113/// ```
114///
115/// The `base_col_type` and `base_type_info` describe the plaintext column type.
116/// The outer column metadata describes the ciphertext transport type (always BigVarBinary).
117#[derive(Debug, Clone)]
118pub struct CryptoMetadata {
119    /// Index into the CEK table (0-based).
120    pub cek_table_ordinal: u16,
121    /// Base user type of the plaintext column.
122    pub base_user_type: u32,
123    /// Base column type byte of the plaintext column.
124    pub base_col_type: u8,
125    /// Type-specific metadata for the plaintext column type.
126    pub base_type_info: crate::token::TypeInfo,
127    /// Encryption algorithm ID.
128    pub algorithm_id: u8,
129    /// Encryption type (deterministic or randomized).
130    pub encryption_type: EncryptionTypeWire,
131    /// Normalization rule version.
132    pub normalization_version: u8,
133}
134
135/// Wire-level encryption type.
136#[derive(Debug, Clone, Copy, PartialEq, Eq)]
137#[non_exhaustive]
138pub enum EncryptionTypeWire {
139    /// Deterministic encryption (value 1).
140    Deterministic,
141    /// Randomized encryption (value 2).
142    Randomized,
143}
144
145impl EncryptionTypeWire {
146    /// Create from wire value.
147    #[must_use]
148    pub fn from_u8(value: u8) -> Option<Self> {
149        match value {
150            ENCRYPTION_TYPE_DETERMINISTIC => Some(Self::Deterministic),
151            ENCRYPTION_TYPE_RANDOMIZED => Some(Self::Randomized),
152            _ => None,
153        }
154    }
155
156    /// Convert to wire value.
157    #[must_use]
158    pub fn to_u8(self) -> u8 {
159        match self {
160            Self::Deterministic => ENCRYPTION_TYPE_DETERMINISTIC,
161            Self::Randomized => ENCRYPTION_TYPE_RANDOMIZED,
162        }
163    }
164}
165
166/// CEK table containing all Column Encryption Keys needed for a result set.
167#[derive(Debug, Clone, Default)]
168pub struct CekTable {
169    /// CEK entries.
170    pub entries: Vec<CekTableEntry>,
171}
172
173impl CekTable {
174    /// Create an empty CEK table.
175    #[must_use]
176    pub fn new() -> Self {
177        Self::default()
178    }
179
180    /// Get a CEK entry by ordinal.
181    #[must_use]
182    pub fn get(&self, ordinal: u16) -> Option<&CekTableEntry> {
183        self.entries.get(ordinal as usize)
184    }
185
186    /// Check if the table is empty.
187    #[must_use]
188    pub fn is_empty(&self) -> bool {
189        self.entries.is_empty()
190    }
191
192    /// Get the number of entries.
193    #[must_use]
194    pub fn len(&self) -> usize {
195        self.entries.len()
196    }
197
198    /// Decode a CEK table from the wire format.
199    ///
200    /// # Wire Format
201    ///
202    /// ```text
203    /// CEK_TABLE:
204    ///   cek_count: USHORT (2 bytes)
205    ///   entries: CEK_ENTRY[cek_count]
206    ///
207    /// CEK_ENTRY:
208    ///   database_id: DWORD (4 bytes)
209    ///   cek_id: DWORD (4 bytes)
210    ///   cek_version: DWORD (4 bytes)
211    ///   cek_md_version: ULONGLONG (8 bytes)
212    ///   value_count: BYTE (1 byte)
213    ///   values: CEK_VALUE[value_count]
214    ///
215    /// CEK_VALUE:
216    ///   encrypted_value_length: USHORT (2 bytes)
217    ///   encrypted_value: BYTE[encrypted_value_length]
218    ///   key_store_name: B_VARCHAR
219    ///   cmk_path: US_VARCHAR
220    ///   algorithm: B_VARCHAR
221    /// ```
222    pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
223        if src.remaining() < 2 {
224            return Err(ProtocolError::UnexpectedEof);
225        }
226
227        let cek_count = src.get_u16_le() as usize;
228
229        let mut entries = Vec::with_capacity(cek_count);
230
231        for _ in 0..cek_count {
232            let entry = CekTableEntry::decode(src)?;
233            entries.push(entry);
234        }
235
236        Ok(Self { entries })
237    }
238}
239
240impl CekTableEntry {
241    /// Decode a CEK table entry from the wire format.
242    pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
243        // database_id (4) + cek_id (4) + cek_version (4) + cek_md_version (8) + value_count (1)
244        if src.remaining() < 21 {
245            return Err(ProtocolError::UnexpectedEof);
246        }
247
248        let database_id = src.get_u32_le();
249        let cek_id = src.get_u32_le();
250        let cek_version = src.get_u32_le();
251        let cek_md_version = src.get_u64_le();
252        let value_count = src.get_u8() as usize;
253
254        let mut values = Vec::with_capacity(value_count);
255
256        for _ in 0..value_count {
257            let value = CekValue::decode(src)?;
258            values.push(value);
259        }
260
261        Ok(Self {
262            database_id,
263            cek_id,
264            cek_version,
265            cek_md_version,
266            values,
267        })
268    }
269
270    /// Get the first (primary) encrypted value.
271    #[must_use]
272    pub fn primary_value(&self) -> Option<&CekValue> {
273        self.values.first()
274    }
275}
276
277impl CekValue {
278    /// Decode a CEK value from the wire format.
279    pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
280        // encrypted_value_length (2 bytes)
281        if src.remaining() < 2 {
282            return Err(ProtocolError::UnexpectedEof);
283        }
284
285        let encrypted_value_length = src.get_u16_le() as usize;
286
287        if src.remaining() < encrypted_value_length {
288            return Err(ProtocolError::UnexpectedEof);
289        }
290
291        let encrypted_value = src.copy_to_bytes(encrypted_value_length);
292
293        // key_store_name (B_VARCHAR)
294        let key_store_provider_name = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
295
296        // cmk_path (US_VARCHAR)
297        let cmk_path = read_us_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
298
299        // algorithm (B_VARCHAR)
300        let encryption_algorithm = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
301
302        Ok(Self {
303            encrypted_value,
304            key_store_provider_name,
305            cmk_path,
306            encryption_algorithm,
307        })
308    }
309}
310
311impl CryptoMetadata {
312    /// Decode crypto metadata from the wire format.
313    ///
314    /// Per MS-TDS 2.2.7.4, the wire format is:
315    /// ```text
316    /// cek_table_ordinal: USHORT (2 bytes)
317    /// base_user_type: ULONG (4 bytes)
318    /// base_col_type: BYTE (1 byte)
319    /// base_type_info: TYPE_INFO (variable, depends on base_col_type)
320    /// algorithm_id: BYTE (1 byte)
321    /// encryption_type: BYTE (1 byte)
322    /// normalization_version: BYTE (1 byte)
323    /// ```
324    pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
325        // ordinal (2) + base_user_type (4) + base_col_type (1) = 7 minimum
326        if src.remaining() < 7 {
327            return Err(ProtocolError::UnexpectedEof);
328        }
329
330        let cek_table_ordinal = src.get_u16_le();
331        let base_user_type = src.get_u32_le();
332        let base_col_type = src.get_u8();
333
334        // Parse base type info using the shared decoder from token.rs
335        let base_type_id =
336            crate::types::TypeId::from_u8(base_col_type).unwrap_or(crate::types::TypeId::Null);
337        let base_type_info = crate::token::decode_type_info(src, base_type_id, base_col_type)?;
338
339        // algorithm_id (1) + encryption_type (1) + normalization_version (1) = 3
340        if src.remaining() < 3 {
341            return Err(ProtocolError::UnexpectedEof);
342        }
343
344        let algorithm_id = src.get_u8();
345        let encryption_type_byte = src.get_u8();
346        let normalization_version = src.get_u8();
347
348        let encryption_type = EncryptionTypeWire::from_u8(encryption_type_byte).ok_or(
349            ProtocolError::InvalidField {
350                field: "encryption_type",
351                value: encryption_type_byte as u32,
352            },
353        )?;
354
355        Ok(Self {
356            cek_table_ordinal,
357            base_user_type,
358            base_col_type,
359            base_type_info,
360            algorithm_id,
361            encryption_type,
362            normalization_version,
363        })
364    }
365
366    /// Check if this uses the standard AEAD algorithm.
367    #[must_use]
368    pub fn is_aead_aes_256(&self) -> bool {
369        self.algorithm_id == ALGORITHM_AEAD_AES_256_CBC_HMAC_SHA256
370    }
371
372    /// Check if this uses deterministic encryption.
373    #[must_use]
374    pub fn is_deterministic(&self) -> bool {
375        self.encryption_type == EncryptionTypeWire::Deterministic
376    }
377
378    /// Check if this uses randomized encryption.
379    #[must_use]
380    pub fn is_randomized(&self) -> bool {
381        self.encryption_type == EncryptionTypeWire::Randomized
382    }
383
384    /// Get the TypeId of the base (plaintext) column type.
385    #[must_use]
386    pub fn base_type_id(&self) -> crate::types::TypeId {
387        crate::types::TypeId::from_u8(self.base_col_type).unwrap_or(crate::types::TypeId::Null)
388    }
389}
390
391/// Extended column metadata with encryption information.
392///
393/// This combines the base column metadata with optional crypto metadata
394/// for Always Encrypted columns.
395#[derive(Debug, Clone, Default)]
396pub struct ColumnCryptoInfo {
397    /// Crypto metadata (if column is encrypted).
398    pub crypto_metadata: Option<CryptoMetadata>,
399}
400
401impl ColumnCryptoInfo {
402    /// Create info for an unencrypted column.
403    #[must_use]
404    pub fn unencrypted() -> Self {
405        Self {
406            crypto_metadata: None,
407        }
408    }
409
410    /// Create info for an encrypted column.
411    #[must_use]
412    pub fn encrypted(metadata: CryptoMetadata) -> Self {
413        Self {
414            crypto_metadata: Some(metadata),
415        }
416    }
417
418    /// Check if this column is encrypted.
419    #[must_use]
420    pub fn is_encrypted(&self) -> bool {
421        self.crypto_metadata.is_some()
422    }
423}
424
425/// Check if a column flags value indicates encryption.
426#[must_use]
427pub fn is_column_encrypted(flags: u16) -> bool {
428    (flags & COLUMN_FLAG_ENCRYPTED) != 0
429}
430
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use bytes::BytesMut;

    // ---- fixture builders -------------------------------------------------

    /// Append `text` as UTF-16LE code units.
    fn put_utf16le(buf: &mut BytesMut, text: &str) {
        for unit in text.encode_utf16() {
            buf.extend_from_slice(&unit.to_le_bytes());
        }
    }

    /// Append a B_VARCHAR: a 1-byte UTF-16 character count followed by UTF-16LE text.
    fn put_b_varchar(buf: &mut BytesMut, text: &str) {
        // Fixture strings are short, so the count always fits in one byte.
        let char_count = text.encode_utf16().count() as u8;
        buf.extend_from_slice(&[char_count]);
        put_utf16le(buf, text);
    }

    /// Append a US_VARCHAR: a 2-byte little-endian UTF-16 character count followed by text.
    fn put_us_varchar(buf: &mut BytesMut, text: &str) {
        let char_count = text.encode_utf16().count() as u16;
        buf.extend_from_slice(&char_count.to_le_bytes());
        put_utf16le(buf, text);
    }

    /// Append a full CEK_VALUE record: length-prefixed key bytes plus its three strings.
    fn put_cek_value(
        buf: &mut BytesMut,
        encrypted_value: &[u8],
        key_store: &str,
        cmk_path: &str,
        algorithm: &str,
    ) {
        let key_len = encrypted_value.len() as u16;
        buf.extend_from_slice(&key_len.to_le_bytes());
        buf.extend_from_slice(encrypted_value);
        put_b_varchar(buf, key_store);
        put_us_varchar(buf, cmk_path);
        put_b_varchar(buf, algorithm);
    }

    /// Append the fixed-width header of a CEK_ENTRY (everything before its values).
    fn put_cek_entry_header(
        buf: &mut BytesMut,
        database_id: u32,
        cek_id: u32,
        cek_version: u32,
        cek_md_version: u64,
        value_count: u8,
    ) {
        buf.extend_from_slice(&database_id.to_le_bytes());
        buf.extend_from_slice(&cek_id.to_le_bytes());
        buf.extend_from_slice(&cek_version.to_le_bytes());
        buf.extend_from_slice(&cek_md_version.to_le_bytes());
        buf.extend_from_slice(&[value_count]);
    }

    // ---- tests ------------------------------------------------------------

    #[test]
    fn test_encryption_type_wire_roundtrip() {
        let parse_cases = [
            (1u8, Some(EncryptionTypeWire::Deterministic)),
            (2, Some(EncryptionTypeWire::Randomized)),
            (0, None),
            (99, None),
        ];
        for (byte, expected) in parse_cases {
            assert_eq!(EncryptionTypeWire::from_u8(byte), expected);
        }

        assert_eq!(EncryptionTypeWire::Deterministic.to_u8(), 1);
        assert_eq!(EncryptionTypeWire::Randomized.to_u8(), 2);
    }

    #[test]
    fn test_crypto_metadata_decode() {
        let wire: [u8; 11] = [
            0x00, 0x00, // cek_table_ordinal = 0
            0x00, 0x00, 0x00, 0x00, // base_user_type = 0
            0x26, // base_col_type = IntN (0x26)
            0x04, // base_type_info: IntN max_length = 4 (INT)
            0x02, // algorithm_id = AEAD_AES_256_CBC_HMAC_SHA256
            0x01, // encryption_type = Deterministic
            0x01, // normalization_version = 1
        ];

        let mut cursor: &[u8] = &wire;
        let md = CryptoMetadata::decode(&mut cursor).unwrap();

        assert_eq!(md.cek_table_ordinal, 0);
        assert_eq!(md.base_user_type, 0);
        assert_eq!(md.base_col_type, 0x26);
        assert_eq!(md.base_type_info.max_length, Some(4));
        assert_eq!(md.algorithm_id, ALGORITHM_AEAD_AES_256_CBC_HMAC_SHA256);
        assert_eq!(md.encryption_type, EncryptionTypeWire::Deterministic);
        assert_eq!(md.normalization_version, 1);
        assert!(md.is_aead_aes_256());
        assert!(md.is_deterministic());
        assert!(!md.is_randomized());
        assert_eq!(md.base_type_id(), crate::types::TypeId::IntN);
    }

    #[test]
    fn test_cek_value_decode() {
        let mut data = BytesMut::new();
        put_cek_value(&mut data, &[0xDE, 0xAD, 0xBE, 0xEF], "TEST", "key1", "RSA");

        let mut cursor: &[u8] = &data[..];
        let value = CekValue::decode(&mut cursor).unwrap();

        assert_eq!(value.encrypted_value.as_ref(), &[0xDE, 0xAD, 0xBE, 0xEF]);
        assert_eq!(value.key_store_provider_name, "TEST");
        assert_eq!(value.cmk_path, "key1");
        assert_eq!(value.encryption_algorithm, "RSA");
    }

    #[test]
    fn test_cek_table_entry_decode() {
        let mut data = BytesMut::new();
        put_cek_entry_header(&mut data, 1, 2, 1, 100, 1);
        put_cek_value(&mut data, &[0x11, 0x22, 0x33, 0x44], "KS", "P", "A");

        let mut cursor: &[u8] = &data[..];
        let entry = CekTableEntry::decode(&mut cursor).expect("should decode entry");

        assert_eq!(entry.database_id, 1);
        assert_eq!(entry.cek_id, 2);
        assert_eq!(entry.cek_version, 1);
        assert_eq!(entry.cek_md_version, 100);
        assert_eq!(entry.values.len(), 1);

        let primary = entry.primary_value().expect("should have primary value");
        assert_eq!(primary.encrypted_value.as_ref(), &[0x11, 0x22, 0x33, 0x44]);
    }

    #[test]
    fn test_cek_table_decode() {
        let mut data = BytesMut::new();
        data.extend_from_slice(&1u16.to_le_bytes()); // cek_count = 1
        put_cek_entry_header(&mut data, 1, 1, 1, 1, 1);
        put_cek_value(&mut data, &[0xAB, 0xCD], "K", "P", "A");

        let mut cursor: &[u8] = &data[..];
        let table = CekTable::decode(&mut cursor).expect("should decode table");

        assert_eq!(table.len(), 1);
        assert!(!table.is_empty());

        let first = table.get(0).expect("should have first entry");
        assert_eq!(first.database_id, 1);
    }

    #[test]
    fn test_is_column_encrypted() {
        let cases = [
            (0x0000u16, false),
            (0x0001, false), // nullable
            (0x0800, true),  // encrypted flag
            (0x0801, true),  // encrypted + nullable
        ];
        for (flags, expected) in cases {
            assert_eq!(is_column_encrypted(flags), expected, "flags = {:#06x}", flags);
        }
    }

    #[test]
    fn test_column_crypto_info() {
        assert!(!ColumnCryptoInfo::unencrypted().is_encrypted());

        let metadata = CryptoMetadata {
            cek_table_ordinal: 0,
            base_user_type: 0,
            base_col_type: 0x26, // IntN
            base_type_info: crate::token::TypeInfo::default(),
            algorithm_id: 2,
            encryption_type: EncryptionTypeWire::Randomized,
            normalization_version: 1,
        };
        assert!(ColumnCryptoInfo::encrypted(metadata).is_encrypted());
    }
}