1use std::collections::HashMap;
79
80use mssql_auth::KeyStoreProvider;
81use tds_protocol::crypto::{CekTable, CekTableEntry, CryptoMetadata, EncryptionTypeWire};
82
83#[cfg(feature = "always-encrypted")]
84use mssql_auth::{AeadEncryptor, CekCache, CekCacheKey, EncryptionError};
85#[cfg(feature = "always-encrypted")]
86use mssql_types::SqlValue;
87#[cfg(feature = "always-encrypted")]
88use std::sync::Arc;
89
90#[cfg(feature = "always-encrypted")]
91use crate::{Error, row::Row, stream::ResultSet};
92#[cfg(feature = "always-encrypted")]
93use tds_protocol::crypto::CekValue;
94
/// Client-side encryption settings: a master switch, the registered key store
/// providers, and the CEK caching preference.
///
/// Note that the derived `Default` yields `enabled == false`, no providers and
/// `cache_ceks == false`, whereas `EncryptionConfig::new()` switches both
/// flags on.
#[derive(Default)]
pub struct EncryptionConfig {
    /// Master switch; `is_ready` reports `false` whenever this is `false`.
    pub enabled: bool,
    /// Registered key store providers, looked up by `provider_name()` in
    /// registration order.
    providers: Vec<Box<dyn KeyStoreProvider>>,
    /// Whether a context built from this configuration caches decrypted CEKs.
    pub cache_ceks: bool,
}
105
106impl EncryptionConfig {
107 #[must_use]
109 pub fn new() -> Self {
110 Self {
111 enabled: true,
112 providers: Vec::new(),
113 cache_ceks: true,
114 }
115 }
116
117 pub fn register_provider(&mut self, provider: impl KeyStoreProvider + 'static) {
119 self.providers.push(Box::new(provider));
120 }
121
122 #[must_use]
124 pub fn with_provider(mut self, provider: impl KeyStoreProvider + 'static) -> Self {
125 self.register_provider(provider);
126 self
127 }
128
129 #[must_use]
131 pub fn with_cek_caching(mut self, enabled: bool) -> Self {
132 self.cache_ceks = enabled;
133 self
134 }
135
136 pub fn get_provider(&self, name: &str) -> Option<&dyn KeyStoreProvider> {
138 self.providers
139 .iter()
140 .find(|p| p.provider_name() == name)
141 .map(|p| p.as_ref())
142 }
143
144 #[must_use]
146 pub fn is_ready(&self) -> bool {
147 self.enabled && !self.providers.is_empty()
148 }
149}
150
151impl std::fmt::Debug for EncryptionConfig {
152 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
153 f.debug_struct("EncryptionConfig")
154 .field("enabled", &self.enabled)
155 .field("provider_count", &self.providers.len())
156 .field("cache_ceks", &self.cache_ceks)
157 .finish()
158 }
159}
160
/// Runtime state for encrypt/decrypt operations: the shared configuration plus
/// a cache of encryptors keyed by CEK identity.
#[cfg(feature = "always-encrypted")]
pub struct EncryptionContext {
    /// Shared configuration; used to look up key store providers by name.
    config: std::sync::Arc<EncryptionConfig>,
    /// Cache of encryptors, keyed by `CekCacheKey` (database id, CEK id, CEK version).
    cek_cache: CekCache,
    /// Snapshot of `config.cache_ceks` taken when the context was built.
    cache_enabled: bool,
}
180
#[cfg(feature = "always-encrypted")]
impl EncryptionContext {
    /// Builds a context around an already shared configuration.
    ///
    /// The `cache_ceks` flag is read once here, and the context starts with an
    /// empty CEK cache.
    pub fn from_arc(config: std::sync::Arc<EncryptionConfig>) -> Self {
        let cache_enabled = config.cache_ceks;
        Self {
            config,
            cek_cache: CekCache::new(),
            cache_enabled,
        }
    }

    /// Builds a context that takes ownership of `config` (wrapping it in an `Arc`).
    pub fn new(config: EncryptionConfig) -> Self {
        Self::from_arc(std::sync::Arc::new(config))
    }

    /// Returns an encryptor for the CEK described by `cek_entry`.
    ///
    /// With caching enabled, a cached encryptor for the entry's
    /// (database id, CEK id, CEK version) key is returned without contacting a
    /// key store. Otherwise the encrypted CEK is decrypted through a registered
    /// key store provider. The entry's primary value is preferred; when its
    /// provider is not registered, the remaining values are tried in order and
    /// the first one with a registered provider is used. This lets a CEK that is
    /// wrapped under several column master keys be used when only one of their
    /// key stores is registered.
    ///
    /// # Errors
    ///
    /// - `EncryptionError::CekDecryptionFailed` when the entry has no primary value.
    /// - `EncryptionError::KeyStoreNotFound` (naming the primary value's
    ///   provider) when none of the entry's values has a registered provider.
    /// - Whatever the provider's `decrypt_cek`, the cache insert, or
    ///   `AeadEncryptor::new` returns on failure.
    pub async fn get_encryptor(
        &self,
        cek_entry: &CekTableEntry,
    ) -> Result<Arc<AeadEncryptor>, EncryptionError> {
        let cache_key = CekCacheKey::new(
            cek_entry.database_id,
            cek_entry.cek_id,
            cek_entry.cek_version,
        );

        if self.cache_enabled {
            if let Some(encryptor) = self.cek_cache.get(&cache_key) {
                return Ok(encryptor);
            }
        }

        let primary = cek_entry
            .primary_value()
            .ok_or_else(|| EncryptionError::CekDecryptionFailed("No CEK value available".into()))?;

        // Prefer the primary value; fall back to any other wrapping of the same
        // CEK whose key store provider is registered.
        let (cek_value, provider) = match self
            .config
            .get_provider(&primary.key_store_provider_name)
        {
            Some(provider) => (primary, provider),
            None => cek_entry
                .values
                .iter()
                .find_map(|candidate| {
                    self.config
                        .get_provider(&candidate.key_store_provider_name)
                        .map(|provider| (candidate, provider))
                })
                .ok_or_else(|| {
                    EncryptionError::KeyStoreNotFound(primary.key_store_provider_name.clone())
                })?,
        };

        let decrypted_cek = provider
            .decrypt_cek(
                &cek_value.cmk_path,
                &cek_value.encryption_algorithm,
                &cek_value.encrypted_value,
            )
            .await?;

        if self.cache_enabled {
            self.cek_cache.insert(cache_key, decrypted_cek)
        } else {
            Ok(Arc::new(AeadEncryptor::new(&decrypted_cek)?))
        }
    }

    /// Encrypts `plaintext` with the CEK described by `cek_entry`.
    ///
    /// # Errors
    ///
    /// Returns `EncryptionError::UnsupportedOperation` when `encryption_type`
    /// is neither `Deterministic` nor `Randomized`; otherwise propagates
    /// failures from [`Self::get_encryptor`] and from the encryptor.
    pub async fn encrypt_value(
        &self,
        plaintext: &[u8],
        cek_entry: &CekTableEntry,
        encryption_type: EncryptionTypeWire,
    ) -> Result<Vec<u8>, EncryptionError> {
        // The encryptor is resolved before the type is validated, matching the
        // order in which errors have always been reported.
        let encryptor = self.get_encryptor(cek_entry).await?;

        let enc_type = match encryption_type {
            EncryptionTypeWire::Deterministic => mssql_auth::EncryptionType::Deterministic,
            EncryptionTypeWire::Randomized => mssql_auth::EncryptionType::Randomized,
            _ => {
                return Err(EncryptionError::UnsupportedOperation(format!(
                    "unsupported encryption type: {encryption_type:?}"
                )));
            }
        };

        encryptor.encrypt(plaintext, enc_type)
    }

    /// Decrypts `ciphertext` with the CEK described by `cek_entry`.
    ///
    /// # Errors
    ///
    /// Propagates failures from [`Self::get_encryptor`] and from the encryptor.
    pub async fn decrypt_value(
        &self,
        ciphertext: &[u8],
        cek_entry: &CekTableEntry,
    ) -> Result<Vec<u8>, EncryptionError> {
        let encryptor = self.get_encryptor(cek_entry).await?;
        encryptor.decrypt(ciphertext)
    }

    /// Empties the CEK cache.
    pub fn clear_cache(&self) {
        self.cek_cache.clear();
    }

    /// Reports whether a key store provider named `name` is registered.
    pub fn has_provider(&self, name: &str) -> bool {
        self.config.get_provider(name).is_some()
    }
}
311
#[cfg(feature = "always-encrypted")]
impl std::fmt::Debug for EncryptionContext {
    /// Prints counts and the caching flag only; neither providers nor cached
    /// entries are formatted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let provider_count = self.config.providers.len();
        let cache_entries = self.cek_cache.len();
        let mut builder = f.debug_struct("EncryptionContext");
        builder.field("provider_count", &provider_count);
        builder.field("cache_entries", &cache_entries);
        builder.field("cache_enabled", &self.cache_enabled);
        builder.finish()
    }
}
322
/// Encryption metadata for one result set: its CEK table plus per-column
/// crypto metadata.
#[derive(Debug, Clone)]
pub struct ResultSetEncryptionInfo {
    /// CEK table that `CryptoMetadata::cek_table_ordinal` values index into.
    pub cek_table: CekTable,
    /// One slot per column, indexed by column ordinal; `None` marks a column
    /// without crypto metadata.
    pub column_crypto: Vec<Option<CryptoMetadata>>,
}
334
335impl ResultSetEncryptionInfo {
336 pub fn new(cek_table: CekTable, column_count: usize) -> Self {
338 Self {
339 cek_table,
340 column_crypto: vec![None; column_count],
341 }
342 }
343
344 pub fn set_column_crypto(&mut self, ordinal: usize, metadata: CryptoMetadata) {
346 if ordinal < self.column_crypto.len() {
347 self.column_crypto[ordinal] = Some(metadata);
348 }
349 }
350
351 pub fn get_cek_for_column(&self, ordinal: usize) -> Option<&CekTableEntry> {
353 let crypto = self.column_crypto.get(ordinal)?.as_ref()?;
354 self.cek_table.get(crypto.cek_table_ordinal)
355 }
356
357 pub fn is_column_encrypted(&self, ordinal: usize) -> bool {
359 self.column_crypto
360 .get(ordinal)
361 .map(|c| c.is_some())
362 .unwrap_or(false)
363 }
364
365 pub fn get_encryption_type(&self, ordinal: usize) -> Option<EncryptionTypeWire> {
367 self.column_crypto
368 .get(ordinal)?
369 .as_ref()
370 .map(|c| c.encryption_type)
371 }
372}
373
/// Encryption metadata for the parameters of one statement: the CEK table and
/// the crypto settings of each parameter that has any.
#[derive(Debug, Clone)]
pub struct ParameterEncryptionInfo {
    /// CEK table that `ParameterCryptoInfo::cek_ordinal` values refer to.
    pub cek_table: CekTable,
    /// Crypto settings keyed by parameter name; `needs_encryption` is a
    /// key-presence check on this map.
    pub parameters: HashMap<String, ParameterCryptoInfo>,
}
385
386impl ParameterEncryptionInfo {
387 pub fn new() -> Self {
389 Self {
390 cek_table: CekTable::new(),
391 parameters: HashMap::new(),
392 }
393 }
394
395 pub fn add_parameter(&mut self, name: String, info: ParameterCryptoInfo) {
397 self.parameters.insert(name, info);
398 }
399
400 pub fn get_parameter(&self, name: &str) -> Option<&ParameterCryptoInfo> {
402 self.parameters.get(name)
403 }
404
405 pub fn needs_encryption(&self, name: &str) -> bool {
407 self.parameters.contains_key(name)
408 }
409}
410
411impl Default for ParameterEncryptionInfo {
412 fn default() -> Self {
413 Self::new()
414 }
415}
416
/// Crypto settings of a single encrypted parameter.
#[derive(Debug, Clone)]
pub struct ParameterCryptoInfo {
    /// Index of the parameter's CEK; `from_describe_result_sets` fills this
    /// with the positional index into the CEK table it builds.
    pub cek_ordinal: u16,
    /// Encryption type reported for the parameter.
    pub encryption_type: EncryptionTypeWire,
    /// Identifier of the encryption algorithm.
    pub algorithm_id: u8,
    /// Version of the normalization rules to apply before encryption.
    pub normalization_rule_version: u8,
}
434
435impl ParameterCryptoInfo {
436 pub fn new(
438 cek_ordinal: u16,
439 encryption_type: EncryptionTypeWire,
440 algorithm_id: u8,
441 normalization_rule_version: u8,
442 ) -> Self {
443 Self {
444 cek_ordinal,
445 encryption_type,
446 algorithm_id,
447 normalization_rule_version,
448 }
449 }
450}
451
#[cfg(feature = "always-encrypted")]
impl ParameterEncryptionInfo {
    /// Minimum column count of the first result set (CEK rows); columns 0-8 are read.
    const RS1_MIN_COLS: usize = 9;
    /// Minimum column count of the second result set (parameter rows); columns 1-5 are read.
    const RS2_MIN_COLS: usize = 6;

    /// Builds the parameter encryption info from the two result sets of
    /// `sp_describe_parameter_encryption`.
    ///
    /// Result set 1 is turned into the CEK table: rows sharing a
    /// `column_encryption_key_ordinal` are grouped into one entry with several
    /// values. Result set 2 yields one `ParameterCryptoInfo` per parameter
    /// whose `column_encryption_type` is non-zero; each server key ordinal is
    /// translated to the positional index of the corresponding CEK entry.
    ///
    /// # Errors
    ///
    /// Returns `Error::Protocol` when fewer than two result sets are given, a
    /// result set has too few columns, a column holds an unexpected type, an
    /// encryption type byte is invalid, a parameter references a key ordinal
    /// missing from result set 1, or there are more CEKs than fit in a `u16`.
    /// Errors from `collect_all` are propagated.
    pub(crate) fn from_describe_result_sets(result_sets: &mut [ResultSet]) -> Result<Self, Error> {
        if result_sets.len() < 2 {
            return Err(Error::Protocol(format!(
                "sp_describe_parameter_encryption returned {} result set(s), expected 2",
                result_sets.len()
            )));
        }

        let rs1_cols = result_sets[0].columns().len();
        // Validate the shape before reading rows so the positional reads below
        // cannot hit a missing column.
        if rs1_cols < Self::RS1_MIN_COLS {
            return Err(Error::Protocol(format!(
                "sp_describe_parameter_encryption result set 1 has {rs1_cols} columns, expected >= {}",
                Self::RS1_MIN_COLS
            )));
        }
        let rs1_rows = result_sets[0].collect_all()?;

        let mut entries: Vec<CekTableEntry> = Vec::new();
        // Maps the server-assigned key ordinal to the entry's position in `entries`.
        let mut ordinal_to_index: HashMap<i32, u16> = HashMap::new();

        for row in &rs1_rows {
            let key_ordinal = describe_int(row, 0, "column_encryption_key_ordinal")?;
            let value = CekValue {
                encrypted_value: describe_varbinary(
                    row,
                    5,
                    "column_encryption_key_encrypted_value",
                )?,
                key_store_provider_name: describe_nvarchar(
                    row,
                    6,
                    "column_master_key_store_provider_name",
                )?,
                cmk_path: describe_nvarchar(row, 7, "column_master_key_path")?,
                encryption_algorithm: describe_nvarchar(
                    row,
                    8,
                    "column_encryption_key_encryption_algorithm_name",
                )?,
            };

            if let Some(&idx) = ordinal_to_index.get(&key_ordinal) {
                // Another value for an ordinal already seen: append it to that entry.
                entries[idx as usize].values.push(value);
            } else {
                let idx = u16::try_from(entries.len()).map_err(|_| {
                    Error::Protocol(
                        "sp_describe_parameter_encryption returned too many CEKs".into(),
                    )
                })?;
                ordinal_to_index.insert(key_ordinal, idx);
                entries.push(CekTableEntry {
                    // NOTE(review): the `as u32` casts reinterpret the signed SQL
                    // ints bit-for-bit; negative values are not rejected.
                    database_id: describe_int(row, 1, "database_id")? as u32,
                    cek_id: describe_int(row, 2, "column_encryption_key_id")? as u32,
                    cek_version: describe_int(row, 3, "column_encryption_key_version")? as u32,
                    cek_md_version: describe_md_version(row, 4)?,
                    values: vec![value],
                });
            }
        }
        let cek_table = CekTable { entries };

        let rs2_cols = result_sets[1].columns().len();
        if rs2_cols < Self::RS2_MIN_COLS {
            return Err(Error::Protocol(format!(
                "sp_describe_parameter_encryption result set 2 has {rs2_cols} columns, expected >= {}",
                Self::RS2_MIN_COLS
            )));
        }
        let rs2_rows = result_sets[1].collect_all()?;

        let mut parameters = HashMap::new();
        for row in &rs2_rows {
            let name = describe_nvarchar(row, 1, "parameter_name")?;
            let encryption_type_byte = describe_tinyint(row, 3, "column_encryption_type")?;
            if encryption_type_byte == 0 {
                // Type 0 rows are skipped before their remaining columns are read.
                continue;
            }
            let encryption_type =
                EncryptionTypeWire::from_u8(encryption_type_byte).ok_or_else(|| {
                    Error::Protocol(format!(
                        "sp_describe_parameter_encryption: invalid column_encryption_type {encryption_type_byte} for {name}"
                    ))
                })?;
            let algorithm_id = describe_tinyint(row, 2, "column_encryption_algorithm")?;
            let server_ordinal = describe_int(row, 4, "column_encryption_key_ordinal")?;
            let normalization_rule_version =
                describe_tinyint(row, 5, "column_encryption_normalization_rule_version")?;

            // Translate the server's key ordinal into the index of the CEK entry built above.
            let cek_ordinal = *ordinal_to_index.get(&server_ordinal).ok_or_else(|| {
                Error::Protocol(format!(
                    "sp_describe_parameter_encryption: parameter {name} references CEK ordinal {server_ordinal} absent from the CEK table"
                ))
            })?;

            parameters.insert(
                name,
                ParameterCryptoInfo {
                    cek_ordinal,
                    encryption_type,
                    algorithm_id,
                    normalization_rule_version,
                },
            );
        }

        Ok(Self {
            cek_table,
            parameters,
        })
    }
}
589
/// Reads column `idx` of `row` as an `int`; `col` is used only in the error
/// message produced when the value is missing or has another type.
#[cfg(feature = "always-encrypted")]
fn describe_int(row: &Row, idx: usize, col: &str) -> Result<i32, Error> {
    const EXPECTED: &str = "int";
    match row.get_raw(idx) {
        Some(SqlValue::Int(number)) => Ok(number),
        Some(mismatch) => Err(describe_type_error(col, idx, EXPECTED, Some(&mismatch))),
        None => Err(describe_type_error(col, idx, EXPECTED, None)),
    }
}
598
/// Reads column `idx` of `row` as a `tinyint`; `col` is used only in the error
/// message produced when the value is missing or has another type.
#[cfg(feature = "always-encrypted")]
fn describe_tinyint(row: &Row, idx: usize, col: &str) -> Result<u8, Error> {
    const EXPECTED: &str = "tinyint";
    match row.get_raw(idx) {
        Some(SqlValue::TinyInt(byte)) => Ok(byte),
        Some(mismatch) => Err(describe_type_error(col, idx, EXPECTED, Some(&mismatch))),
        None => Err(describe_type_error(col, idx, EXPECTED, None)),
    }
}
607
/// Reads column `idx` of `row` as a string; `col` is used only in the error
/// message produced when the value is missing or has another type.
#[cfg(feature = "always-encrypted")]
fn describe_nvarchar(row: &Row, idx: usize, col: &str) -> Result<String, Error> {
    const EXPECTED: &str = "nvarchar";
    match row.get_raw(idx) {
        Some(SqlValue::String(text)) => Ok(text),
        Some(mismatch) => Err(describe_type_error(col, idx, EXPECTED, Some(&mismatch))),
        None => Err(describe_type_error(col, idx, EXPECTED, None)),
    }
}
616
/// Reads column `idx` of `row` as binary data; `col` is used only in the error
/// message produced when the value is missing or has another type.
#[cfg(feature = "always-encrypted")]
fn describe_varbinary(row: &Row, idx: usize, col: &str) -> Result<bytes::Bytes, Error> {
    const EXPECTED: &str = "varbinary";
    match row.get_raw(idx) {
        Some(SqlValue::Binary(payload)) => Ok(payload),
        Some(mismatch) => Err(describe_type_error(col, idx, EXPECTED, Some(&mismatch))),
        None => Err(describe_type_error(col, idx, EXPECTED, None)),
    }
}
625
/// Reads column `idx` of `row` as an 8-byte binary value and decodes it as a
/// little-endian `u64`.
///
/// A missing value, a non-binary value, or a binary value whose length is not
/// exactly 8 produces the "expected binary(8)" protocol error.
#[cfg(feature = "always-encrypted")]
fn describe_md_version(row: &Row, idx: usize) -> Result<u64, Error> {
    let raw = row.get_raw(idx);
    if let Some(SqlValue::Binary(payload)) = &raw {
        // The array conversion succeeds only for a slice of exactly 8 bytes.
        if let Ok(fixed) = <[u8; 8]>::try_from(&payload[..]) {
            return Ok(u64::from_le_bytes(fixed));
        }
    }
    Err(describe_type_error(
        "column_encryption_key_metadata_version",
        idx,
        "binary(8)",
        raw.as_ref(),
    ))
}
643
/// Builds the protocol error reported when a describe column is missing or
/// holds a value of an unexpected type.
#[cfg(feature = "always-encrypted")]
fn describe_type_error(col: &str, idx: usize, expected: &str, got: Option<&SqlValue>) -> Error {
    let got = match got {
        Some(value) => value.type_name(),
        None => "missing",
    };
    let message = format!(
        "sp_describe_parameter_encryption column {col} (#{idx}): expected {expected}, got {got}"
    );
    Error::Protocol(message)
}
652
/// Converts a [`SqlValue`] into the plaintext byte layout that is handed to
/// the Always Encrypted cipher.
///
/// Layouts produced:
/// - `Bool`, `TinyInt`, `SmallInt`, `Int`: widened to `i64`, 8 little-endian bytes.
/// - `BigInt`: 8 little-endian bytes.
/// - `Float`, `Double`: little-endian bytes of the stored floating-point value.
/// - `String`: UTF-16 code units, little-endian, no terminator.
/// - `Binary`: the bytes unchanged.
/// - `Uuid` (feature `uuid`): the first three groups byte-reversed, the last
///   eight bytes unchanged.
/// - `Date` (feature `chrono`): days since 0001-01-01 as 3 little-endian bytes.
/// - `Decimal` (feature `decimal`): a sign byte (1 = non-negative,
///   0 = negative) followed by the 16 little-endian bytes of the absolute
///   mantissa. The scale is not encoded.
/// - `Money`, `SmallMoney` (feature `decimal`): the value scaled to four
///   decimal places as an `i64`, written as the high 32 bits followed by the
///   low 32 bits, each little-endian.
///
/// # Errors
///
/// Returns `EncryptionError::UnsupportedOperation` for any other variant, for
/// a money value that does not fit in an `i64` after scaling, and for a date
/// outside 0001-01-01..=9999-12-31. Such a date has no 3-byte encoding, and
/// previously it wrapped or truncated into the encoding of a different date.
#[cfg(feature = "always-encrypted")]
pub fn normalize_for_encryption(value: &SqlValue) -> Result<Vec<u8>, EncryptionError> {
    match value {
        SqlValue::Bool(v) => Ok(i64::from(*v).to_le_bytes().to_vec()),
        SqlValue::TinyInt(v) => Ok(i64::from(*v).to_le_bytes().to_vec()),
        SqlValue::SmallInt(v) => Ok(i64::from(*v).to_le_bytes().to_vec()),
        SqlValue::Int(v) => Ok(i64::from(*v).to_le_bytes().to_vec()),
        SqlValue::BigInt(v) => Ok(v.to_le_bytes().to_vec()),
        SqlValue::Float(v) => Ok(v.to_le_bytes().to_vec()),
        SqlValue::Double(v) => Ok(v.to_le_bytes().to_vec()),
        SqlValue::String(s) => Ok(s.encode_utf16().flat_map(u16::to_le_bytes).collect()),
        SqlValue::Binary(b) => Ok(b.to_vec()),
        #[cfg(feature = "uuid")]
        SqlValue::Uuid(u) => {
            let b = u.as_bytes();
            Ok(vec![
                b[3], b[2], b[1], b[0], b[5], b[4], b[7], b[6], b[8], b[9], b[10], b[11], b[12],
                b[13], b[14], b[15],
            ])
        }
        #[cfg(feature = "chrono")]
        SqlValue::Date(d) => {
            use chrono::Datelike;
            // Day offset of 9999-12-31 from 0001-01-01, the last supported date.
            const MAX_DATE_DAYS: i32 = 3_652_058;
            let days = d.num_days_from_ce() - 1;
            if !(0..=MAX_DATE_DAYS).contains(&days) {
                return Err(EncryptionError::UnsupportedOperation(
                    "DATE value out of range".into(),
                ));
            }
            // `days` is non-negative and below 2^24 here, so the low three
            // little-endian bytes carry the whole value.
            Ok(days.to_le_bytes()[..3].to_vec())
        }
        #[cfg(feature = "decimal")]
        SqlValue::Decimal(d) => {
            let mut out = Vec::with_capacity(17);
            out.push(u8::from(!d.is_sign_negative()));
            out.extend_from_slice(&d.mantissa().unsigned_abs().to_le_bytes());
            Ok(out)
        }
        #[cfg(feature = "decimal")]
        SqlValue::Money(d) | SqlValue::SmallMoney(d) => {
            let cents = money_cents(d)?;
            // High 32 bits first, then the low 32 bits, each little-endian.
            let mut out = ((cents >> 32) as i32).to_le_bytes().to_vec();
            out.extend_from_slice(&(cents as u32).to_le_bytes());
            Ok(out)
        }
        other => Err(EncryptionError::UnsupportedOperation(format!(
            "Always Encrypted parameter encryption is not yet implemented for {}",
            other.type_name()
        ))),
    }
}
722
/// Rescales `value` to exactly four decimal places and returns the resulting
/// integer (i.e. the amount in ten-thousandths).
///
/// Digits beyond the fourth decimal place are dropped by integer division.
///
/// # Errors
///
/// Returns `EncryptionError::UnsupportedOperation("MONEY value out of range")`
/// when the rescaled amount overflows during scaling or does not fit in an `i64`.
#[cfg(all(feature = "always-encrypted", feature = "decimal"))]
fn money_cents(value: &rust_decimal::Decimal) -> Result<i64, EncryptionError> {
    let out_of_range =
        || EncryptionError::UnsupportedOperation("MONEY value out of range".into());

    let unscaled = value.mantissa();
    let rescaled: i128 = match value.scale() {
        // Fewer than (or exactly) four decimals: scale the mantissa up.
        digits if digits <= 4 => unscaled
            .checked_mul(10_i128.pow(4 - digits))
            .ok_or_else(|| out_of_range())?,
        // More than four decimals: scale the mantissa down.
        digits => unscaled / 10_i128.pow(digits - 4),
    };

    i64::try_from(rescaled).map_err(|_| out_of_range())
}
741
742#[cfg(test)]
743#[allow(clippy::unwrap_used, clippy::expect_used)]
744mod tests {
745 use super::*;
746
    /// Normalizes an int, a string and a binary value, encrypts each
    /// deterministically under a fixed CEK, and compares the result with a
    /// reference ciphertext (per the assertion message, the references are
    /// expected to match Microsoft.Data.SqlClient).
    #[cfg(feature = "always-encrypted")]
    #[test]
    fn ae_normalization_matches_dotnet() {
        use bytes::Bytes;

        // Decodes a hex string two characters at a time; panics on malformed input.
        fn unhex(s: &str) -> Vec<u8> {
            (0..s.len())
                .step_by(2)
                .map(|i| u8::from_str_radix(&s[i..i + 2], 16).unwrap())
                .collect()
        }

        let cek = unhex("B59D9F2C96784C232D53AB273D257DC79B7D2355BB82B1EC7054CE25E25F7B44");
        let enc = AeadEncryptor::new(&cek).unwrap();

        // (value to normalize, expected ciphertext as hex)
        for (value, reference) in [
            (
                SqlValue::Int(42),
                "01102FC5DEC5D3E463A8F4BDF512AA74E6AB953BA9A2F3F9A98CD18446B007DE5A6E2A1D1EB775035EA189CA5160A935CE093CAA9BB7E9233BB333AADEE86FDE1D",
            ),
            (
                SqlValue::String("Ada".to_string()),
                "01BFAC40E6DA541ACEFAD8ECF5598DB77B0C5349CFACBC3C9221C01B6037E593B78E8F398F620F837BD6A4A2B644125C4188DF278B94479B2218466D91107FE417",
            ),
            (
                SqlValue::Binary(Bytes::from_static(&[0x01, 0x02, 0x03])),
                "01ADE71457495F00FC9A16456F1B1EECB901D88DE97887025C189B1C4432E02071AB7594C48518CA5621E90165FAE337475B4CF3A3D00EF2D862FB0473713DF1E1",
            ),
        ] {
            let norm = normalize_for_encryption(&value).unwrap();
            let cipher = enc
                .encrypt(&norm, mssql_auth::EncryptionType::Deterministic)
                .unwrap();
            assert_eq!(
                cipher,
                unhex(reference),
                "ciphertext for {} must match Microsoft.Data.SqlClient",
                value.type_name()
            );
        }
    }
793
794 #[cfg(feature = "always-encrypted")]
799 #[test]
800 fn ae_normalization_rejects_unnormalizable_value() {
801 assert!(normalize_for_encryption(&SqlValue::Null).is_err());
802 }
803
    /// Same check as the basic reference test, covering the remaining integer
    /// widths, bool and both floating-point variants under a second fixed CEK.
    #[cfg(feature = "always-encrypted")]
    #[test]
    fn ae_normalization_matches_dotnet_numeric() {
        // Decodes a hex string two characters at a time; panics on malformed input.
        fn unhex(s: &str) -> Vec<u8> {
            (0..s.len())
                .step_by(2)
                .map(|i| u8::from_str_radix(&s[i..i + 2], 16).unwrap())
                .collect()
        }

        let cek = unhex("9590E42A8A6C8F13B5D09B8D5A128EF8B3A4A10301C7AF24AFC62ED0E02342F7");
        let enc = AeadEncryptor::new(&cek).unwrap();

        // (value to normalize, expected ciphertext as hex)
        for (value, reference) in [
            (
                SqlValue::BigInt(0x0102030405060708),
                "01E765FC4696660028BFD48FCAEAED81E0EB423CFF433CA97F1B2FF02F70744E7265C2AE73CAA562FFA98AF98CB1D3EF6A4649B3640359E1DB7D170C80E639DA68",
            ),
            (
                SqlValue::SmallInt(258),
                "012545AB817E1AEBDCEE1C00AEBFF3A013CAD20E0377BEFDD9186C263F8D1A909C313A753996F1B5E4A4AE17E901F6F781DCA707544766995D339601CA414063A0",
            ),
            (
                SqlValue::TinyInt(200),
                "01A97C33480277D16FFAEDA9068173D4173378542F2887EBCD31CDEEEB116BD59D48F9D459BDDCABAE469E891B4F82AA3D283440CA1B5E9FFC150F9D0AE54EC21E",
            ),
            (
                SqlValue::Bool(true),
                "01DDE18564051D630EE026331BCCAFC8F4122CC3919F81459F37D9C0E0C64A5317FCA08660FE5FC855917B97B72013F25B85ADD14ADDD7D5ED022EB1297FF29A7E",
            ),
            (
                SqlValue::Float(3.5),
                "017A452760E7BA7AA6A716F6707F55D9C3A81683C04A6B561B13AC1D8A848E93E239BB922EE3EE628B6D0081A590BB11747CC25D216240FB10171A0FA3B99A2DB3",
            ),
            (
                SqlValue::Double(3.5),
                "0171611557351FBC4561EBF0B9C98E0DC38AD2BD3E2C1D1E82F185D7E67D0425E506D11DD67BA3EB38F34FB01A8FCEF7E4B9A7256944334A521526613CFF6C8C5F",
            ),
        ] {
            let norm = normalize_for_encryption(&value).unwrap();
            let cipher = enc
                .encrypt(&norm, mssql_auth::EncryptionType::Deterministic)
                .unwrap();
            assert_eq!(
                cipher,
                unhex(reference),
                "ciphertext for {} must match Microsoft.Data.SqlClient",
                value.type_name()
            );
        }
    }
860
    /// Reference-ciphertext check for the `Uuid` and `Date` normalizations;
    /// compiled only when the `uuid` and `chrono` features are both enabled.
    #[cfg(all(feature = "always-encrypted", feature = "uuid", feature = "chrono"))]
    #[test]
    fn ae_normalization_matches_dotnet_uuid_date() {
        // Decodes a hex string two characters at a time; panics on malformed input.
        fn unhex(s: &str) -> Vec<u8> {
            (0..s.len())
                .step_by(2)
                .map(|i| u8::from_str_radix(&s[i..i + 2], 16).unwrap())
                .collect()
        }

        let cek = unhex("9590E42A8A6C8F13B5D09B8D5A128EF8B3A4A10301C7AF24AFC62ED0E02342F7");
        let enc = AeadEncryptor::new(&cek).unwrap();

        // (value to normalize, expected ciphertext as hex)
        for (value, reference) in [
            (
                SqlValue::Uuid(
                    uuid::Uuid::parse_str("01020304-0506-0708-090a-0b0c0d0e0f10").unwrap(),
                ),
                "01F58635AA18692D68BDF551ECDD7AC3A56682D3F91F111F8D8F36D5425C405A8F6AB3ED3C3666444478476BD65FF40DC83F6831F502826AFEEC3116F71A7A2020CCD254F4BA28FCDC0F96BA2E5264AE9E",
            ),
            (
                SqlValue::Date(chrono::NaiveDate::from_ymd_opt(2024, 3, 15).unwrap()),
                "0188B4F75A1F4BDA53C9CDDC1918C09CB57F68E13F5560F1F1D7168FE70707337B1156A97915B244F3C03D3E7352882A599511BD243471FD03683F371CF44E4B76",
            ),
        ] {
            let norm = normalize_for_encryption(&value).unwrap();
            let cipher = enc
                .encrypt(&norm, mssql_auth::EncryptionType::Deterministic)
                .unwrap();
            assert_eq!(
                cipher,
                unhex(reference),
                "ciphertext for {} must match Microsoft.Data.SqlClient",
                value.type_name()
            );
        }
    }
901
    /// Reference-ciphertext check for `Decimal`, `Money` and `SmallMoney`;
    /// compiled only when the `decimal` feature is enabled. `Money` and
    /// `SmallMoney` share one reference because both go through the same
    /// normalization arm.
    #[cfg(all(feature = "always-encrypted", feature = "decimal"))]
    #[test]
    fn ae_normalization_matches_dotnet_decimal_money() {
        // Decodes a hex string two characters at a time; panics on malformed input.
        fn unhex(s: &str) -> Vec<u8> {
            (0..s.len())
                .step_by(2)
                .map(|i| u8::from_str_radix(&s[i..i + 2], 16).unwrap())
                .collect()
        }

        let cek = unhex("CBFB5AE21FB517C65DA0C6E8E11969C630798E473EF5827A70398012DF1D4B9E");
        let enc = AeadEncryptor::new(&cek).unwrap();
        // Mantissa 123456789 with scale 4, i.e. 12345.6789.
        let dec = rust_decimal::Decimal::new(123_456_789, 4);
        // Mantissa 123400 with scale 4, i.e. 12.3400.
        let money = rust_decimal::Decimal::new(123_400, 4);
        // (value to normalize, expected ciphertext as hex)
        for (value, reference) in [
            (
                SqlValue::Decimal(dec),
                "018FAE46024B9B406C23600E6A9C694F9A9B39B785A995689EBE19437BA7E75768011A035A5B54B5E495512EBB46AE1146130940A0D0D834D61AA89B5AD9F71FFAF6EEEAE77E4856BA2AA5E016E2950A8D",
            ),
            (
                SqlValue::Money(money),
                "01B4CE4CAD8D6B241A1555C377A0ADD4C79424DD5162F710D116594F725C1BAB015169A0C7716076EEC90E013519B961DEF427BFC32462D9E45D166C791B73F793",
            ),
            (
                SqlValue::SmallMoney(money),
                "01B4CE4CAD8D6B241A1555C377A0ADD4C79424DD5162F710D116594F725C1BAB015169A0C7716076EEC90E013519B961DEF427BFC32462D9E45D166C791B73F793",
            ),
        ] {
            let norm = normalize_for_encryption(&value).unwrap();
            let cipher = enc
                .encrypt(&norm, mssql_auth::EncryptionType::Deterministic)
                .unwrap();
            assert_eq!(
                cipher,
                unhex(reference),
                "ciphertext for {} must match Microsoft.Data.SqlClient",
                value.type_name()
            );
        }
    }
947
948 #[test]
949 fn test_encryption_config_defaults() {
950 let config = EncryptionConfig::new();
951 assert!(config.enabled);
952 assert!(config.cache_ceks);
953 assert!(!config.is_ready()); }
955
956 #[test]
957 fn test_result_set_encryption_info() {
958 let cek_table = CekTable::new();
959 let mut info = ResultSetEncryptionInfo::new(cek_table, 3);
960
961 assert!(!info.is_column_encrypted(0));
962 assert!(!info.is_column_encrypted(1));
963 assert!(!info.is_column_encrypted(2));
964
965 let metadata = CryptoMetadata {
966 cek_table_ordinal: 0,
967 base_user_type: 0,
968 base_col_type: 0x26,
969 base_type_info: tds_protocol::token::TypeInfo::default(),
970 algorithm_id: 2,
971 encryption_type: EncryptionTypeWire::Deterministic,
972 normalization_version: 1,
973 };
974
975 info.set_column_crypto(1, metadata);
976 assert!(!info.is_column_encrypted(0));
977 assert!(info.is_column_encrypted(1));
978 assert!(!info.is_column_encrypted(2));
979
980 assert_eq!(
981 info.get_encryption_type(1),
982 Some(EncryptionTypeWire::Deterministic)
983 );
984 }
985
986 #[test]
987 fn test_parameter_encryption_info() {
988 let mut info = ParameterEncryptionInfo::new();
989
990 assert!(!info.needs_encryption("@p1"));
991
992 let crypto = ParameterCryptoInfo::new(0, EncryptionTypeWire::Randomized, 2, 1);
993 info.add_parameter("@p1".to_string(), crypto);
994
995 assert!(info.needs_encryption("@p1"));
996 assert!(!info.needs_encryption("@p2"));
997
998 let param = info.get_parameter("@p1").unwrap();
999 assert_eq!(param.encryption_type, EncryptionTypeWire::Randomized);
1000 }
1001
    /// Feeds hand-built describe result sets through
    /// `from_describe_result_sets` and checks that rows sharing a key ordinal
    /// are grouped into one CEK entry, that server key ordinals are mapped to
    /// positional indexes, and that the parameter with encryption type 0 is
    /// left out.
    #[cfg(feature = "always-encrypted")]
    #[test]
    fn parse_describe_result_sets_groups_ceks_and_skips_plaintext() {
        use crate::row::{Column, Row};
        use crate::stream::ResultSet;
        use bytes::Bytes;

        // Builds a result set with `n_cols` placeholder columns and the given row values.
        fn rs(n_cols: usize, rows: Vec<Vec<SqlValue>>) -> ResultSet {
            let cols: Vec<Column> = (0..n_cols)
                .map(|i| Column::new(format!("c{i}"), i, "x"))
                .collect();
            let rows = rows
                .into_iter()
                .map(|vals| Row::from_values(cols.clone(), vals))
                .collect();
            ResultSet::new(cols, rows)
        }

        // Metadata versions as 8 little-endian bytes: 1 and 255.
        let mdv1 = Bytes::from_static(&[1, 0, 0, 0, 0, 0, 0, 0]);
        let mdv2 = Bytes::from_static(&[255, 0, 0, 0, 0, 0, 0, 0]);
        // Result set 1: two rows for key ordinal 1 (different providers) and one for ordinal 2.
        let rs1 = rs(
            9,
            vec![
                vec![
                    SqlValue::Int(1),
                    SqlValue::Int(7),
                    SqlValue::Int(56),
                    SqlValue::Int(1),
                    SqlValue::Binary(mdv1.clone()),
                    SqlValue::Binary(Bytes::from_static(b"env-a")),
                    SqlValue::String("IN_MEMORY_KEY_STORE".into()),
                    SqlValue::String("path-a".into()),
                    SqlValue::String("RSA_OAEP".into()),
                ],
                vec![
                    SqlValue::Int(1),
                    SqlValue::Int(7),
                    SqlValue::Int(56),
                    SqlValue::Int(1),
                    SqlValue::Binary(mdv1),
                    SqlValue::Binary(Bytes::from_static(b"env-a2")),
                    SqlValue::String("PROV_2".into()),
                    SqlValue::String("path-a2".into()),
                    SqlValue::String("RSA_OAEP".into()),
                ],
                vec![
                    SqlValue::Int(2),
                    SqlValue::Int(7),
                    SqlValue::Int(57),
                    SqlValue::Int(1),
                    SqlValue::Binary(mdv2),
                    SqlValue::Binary(Bytes::from_static(b"env-b")),
                    SqlValue::String("IN_MEMORY_KEY_STORE".into()),
                    SqlValue::String("path-b".into()),
                    SqlValue::String("RSA_OAEP".into()),
                ],
            ],
        );

        // Result set 2: a deterministic, a randomized and a type-0 parameter.
        let rs2 = rs(
            6,
            vec![
                vec![
                    SqlValue::Int(1),
                    SqlValue::String("@det".into()),
                    SqlValue::TinyInt(2),
                    SqlValue::TinyInt(1),
                    SqlValue::Int(1),
                    SqlValue::TinyInt(1),
                ],
                vec![
                    SqlValue::Int(2),
                    SqlValue::String("@rand".into()),
                    SqlValue::TinyInt(2),
                    SqlValue::TinyInt(2),
                    SqlValue::Int(2),
                    SqlValue::TinyInt(1),
                ],
                vec![
                    SqlValue::Int(3),
                    SqlValue::String("@plain".into()),
                    SqlValue::TinyInt(0),
                    SqlValue::TinyInt(0),
                    SqlValue::Int(0),
                    SqlValue::TinyInt(0),
                ],
            ],
        );

        let mut sets = vec![rs1, rs2];
        let info = ParameterEncryptionInfo::from_describe_result_sets(&mut sets).unwrap();

        assert_eq!(info.cek_table.len(), 2);
        let e0 = info.cek_table.get(0).unwrap();
        assert_eq!(e0.cek_id, 56);
        assert_eq!(e0.cek_md_version, 1);
        assert_eq!(e0.values.len(), 2, "two CMK-wrappings group under one CEK");
        assert_eq!(e0.values[0].key_store_provider_name, "IN_MEMORY_KEY_STORE");
        assert_eq!(e0.values[1].key_store_provider_name, "PROV_2");
        let e1 = info.cek_table.get(1).unwrap();
        assert_eq!(e1.cek_id, 57);
        assert_eq!(e1.cek_md_version, 255);

        let det = info.get_parameter("@det").unwrap();
        assert_eq!(det.encryption_type, EncryptionTypeWire::Deterministic);
        assert_eq!(det.algorithm_id, 2);
        assert_eq!(det.normalization_rule_version, 1);
        assert_eq!(det.cek_ordinal, 0, "server ordinal 1 -> positional index 0");

        let rand = info.get_parameter("@rand").unwrap();
        assert_eq!(rand.encryption_type, EncryptionTypeWire::Randomized);
        assert_eq!(
            rand.cek_ordinal, 1,
            "server ordinal 2 -> positional index 1"
        );

        assert!(!info.needs_encryption("@plain"));
        assert_eq!(info.parameters.len(), 2);
    }
1131
1132 #[cfg(feature = "always-encrypted")]
1135 #[test]
1136 fn parse_describe_result_sets_rejects_missing_result_set() {
1137 use crate::row::{Column, Row};
1138 use crate::stream::ResultSet;
1139
1140 let cols: Vec<Column> = (0..9)
1141 .map(|i| Column::new(format!("c{i}"), i, "x"))
1142 .collect();
1143 let mut sets = vec![ResultSet::new(cols, Vec::<Row>::new())];
1144 assert!(ParameterEncryptionInfo::from_describe_result_sets(&mut sets).is_err());
1145 }
1146}