1#[cfg(feature = "always-encrypted")]
103use std::collections::HashMap;
104
105use mssql_auth::KeyStoreProvider;
106#[cfg(feature = "always-encrypted")]
107use tds_protocol::crypto::{CekTable, CekTableEntry, EncryptionTypeWire};
108
109#[cfg(feature = "always-encrypted")]
110use mssql_auth::{AeadEncryptor, CekCache, CekCacheKey, EncryptionError};
111#[cfg(feature = "always-encrypted")]
112use mssql_types::SqlValue;
113#[cfg(feature = "always-encrypted")]
114use std::sync::Arc;
115
116#[cfg(feature = "always-encrypted")]
117use crate::{Error, row::Row, stream::ResultSet};
118#[cfg(feature = "always-encrypted")]
119use tds_protocol::crypto::CekValue;
120
121#[derive(Default)]
123pub struct EncryptionConfig {
124 pub enabled: bool,
126 providers: Vec<Box<dyn KeyStoreProvider>>,
128 pub cache_ceks: bool,
130}
131
132impl EncryptionConfig {
133 #[must_use]
135 pub fn new() -> Self {
136 Self {
137 enabled: true,
138 providers: Vec::new(),
139 cache_ceks: true,
140 }
141 }
142
143 pub fn register_provider(&mut self, provider: impl KeyStoreProvider + 'static) {
145 self.providers.push(Box::new(provider));
146 }
147
148 #[must_use]
150 pub fn with_provider(mut self, provider: impl KeyStoreProvider + 'static) -> Self {
151 self.register_provider(provider);
152 self
153 }
154
155 #[must_use]
157 pub fn with_cek_caching(mut self, enabled: bool) -> Self {
158 self.cache_ceks = enabled;
159 self
160 }
161
162 pub fn get_provider(&self, name: &str) -> Option<&dyn KeyStoreProvider> {
164 self.providers
165 .iter()
166 .find(|p| p.provider_name() == name)
167 .map(|p| p.as_ref())
168 }
169
170 #[must_use]
172 pub fn is_ready(&self) -> bool {
173 self.enabled && !self.providers.is_empty()
174 }
175}
176
177impl std::fmt::Debug for EncryptionConfig {
178 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
179 f.debug_struct("EncryptionConfig")
180 .field("enabled", &self.enabled)
181 .field("provider_count", &self.providers.len())
182 .field("cache_ceks", &self.cache_ceks)
183 .finish()
184 }
185}
186
/// Per-connection Always Encrypted state: the shared configuration plus a
/// cache of encryptors built from decrypted column encryption keys.
#[cfg(feature = "always-encrypted")]
pub(crate) struct EncryptionContext {
    /// Shared configuration; key store providers are looked up here.
    config: Arc<EncryptionConfig>,
    /// Encryptors keyed by (database id, CEK id, CEK version).
    cek_cache: CekCache,
    /// Copy of `config.cache_ceks`, captured when the context is created.
    cache_enabled: bool,
}
206
#[cfg(feature = "always-encrypted")]
impl EncryptionContext {
    /// Creates a context that shares `config` and starts with an empty CEK cache.
    ///
    /// Whether the cache is consulted is decided once, here, from
    /// `config.cache_ceks`.
    pub fn from_arc(config: std::sync::Arc<EncryptionConfig>) -> Self {
        let cache_enabled = config.cache_ceks;
        Self {
            config,
            cek_cache: CekCache::new(),
            cache_enabled,
        }
    }

    /// Returns an AEAD encryptor for the column encryption key described by
    /// `cek_entry`.
    ///
    /// When caching is enabled, an encryptor previously built for the same
    /// (database id, CEK id, CEK version) is reused. Otherwise every encrypted
    /// value of the entry is tried in order; the first one whose key store
    /// provider is registered and successfully decrypts the CEK is used.
    ///
    /// # Errors
    ///
    /// - the last decryption error, if at least one registered provider was
    ///   tried and every attempt failed;
    /// - `KeyStoreNotFound` naming the first unregistered provider, if none of
    ///   the entry's providers is registered;
    /// - `CekDecryptionFailed("No CEK value available")` if the entry has no
    ///   values;
    /// - any error from building the encryptor or inserting it into the cache.
    pub(crate) async fn get_encryptor(
        &self,
        cek_entry: &CekTableEntry,
    ) -> Result<Arc<AeadEncryptor>, EncryptionError> {
        let cache_key = CekCacheKey::new(
            cek_entry.database_id,
            cek_entry.cek_id,
            cek_entry.cek_version,
        );

        if self.cache_enabled {
            // Cache hit: skip the key store round trip entirely.
            if let Some(encryptor) = self.cek_cache.get(&cache_key) {
                return Ok(encryptor);
            }
        }

        // A CEK entry may carry several encrypted values, one per wrapping CMK /
        // key store. Only trying the first one would fail whenever that
        // particular key store is unavailable even though another could decrypt
        // the same key, so every value is a candidate.
        let mut first_missing_provider: Option<&str> = None;
        let mut last_decrypt_error: Option<EncryptionError> = None;

        for cek_value in &cek_entry.values {
            let provider = match self.config.get_provider(&cek_value.key_store_provider_name) {
                Some(provider) => provider,
                None => {
                    first_missing_provider
                        .get_or_insert(cek_value.key_store_provider_name.as_str());
                    continue;
                }
            };

            let decrypted = provider
                .decrypt_cek(
                    &cek_value.cmk_path,
                    &cek_value.encryption_algorithm,
                    &cek_value.encrypted_value,
                )
                .await;

            match decrypted {
                Ok(decrypted_cek) => {
                    return if self.cache_enabled {
                        self.cek_cache.insert(cache_key, decrypted_cek)
                    } else {
                        Ok(Arc::new(AeadEncryptor::new(&decrypted_cek)?))
                    };
                }
                // Remember the failure and fall through to the next wrapping.
                Err(err) => last_decrypt_error = Some(err.into()),
            }
        }

        // Nothing worked. A real decryption failure is more informative than a
        // missing provider, which in turn is more informative than "no values".
        Err(match (last_decrypt_error, first_missing_provider) {
            (Some(err), _) => err,
            (None, Some(name)) => EncryptionError::KeyStoreNotFound(name.to_owned()),
            (None, None) => {
                EncryptionError::CekDecryptionFailed("No CEK value available".into())
            }
        })
    }

    /// Encrypts already-normalized `plaintext` with the CEK described by
    /// `cek_entry`.
    ///
    /// # Errors
    ///
    /// `UnsupportedOperation` for any wire encryption type other than
    /// deterministic or randomized; otherwise any error from obtaining the
    /// encryptor or from the encryption itself.
    pub(crate) async fn encrypt_value(
        &self,
        plaintext: &[u8],
        cek_entry: &CekTableEntry,
        encryption_type: EncryptionTypeWire,
    ) -> Result<Vec<u8>, EncryptionError> {
        let encryptor = self.get_encryptor(cek_entry).await?;

        let enc_type = match encryption_type {
            EncryptionTypeWire::Deterministic => mssql_auth::EncryptionType::Deterministic,
            EncryptionTypeWire::Randomized => mssql_auth::EncryptionType::Randomized,
            _ => {
                return Err(EncryptionError::UnsupportedOperation(format!(
                    "unsupported encryption type: {encryption_type:?}"
                )));
            }
        };

        encryptor.encrypt(plaintext, enc_type)
    }

    /// True if a key store provider named `name` is registered in the shared
    /// configuration.
    pub fn has_provider(&self, name: &str) -> bool {
        self.config.get_provider(name).is_some()
    }
}
310
#[cfg(feature = "always-encrypted")]
impl std::fmt::Debug for EncryptionContext {
    /// Prints counts and flags only; no key material is printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let provider_count = self.config.providers.len();
        let cache_entries = self.cek_cache.len();
        let mut out = f.debug_struct("EncryptionContext");
        out.field("provider_count", &provider_count);
        out.field("cache_entries", &cache_entries);
        out.field("cache_enabled", &self.cache_enabled);
        out.finish()
    }
}
321
/// Parsed output of `sp_describe_parameter_encryption`: the CEKs referenced by
/// the statement and, for each encrypted parameter, how to encrypt it.
#[cfg(feature = "always-encrypted")]
#[derive(Debug, Clone)]
pub(crate) struct ParameterEncryptionInfo {
    /// CEK entries in first-seen order; `ParameterCryptoInfo::cek_ordinal`
    /// indexes into this table by position.
    pub cek_table: CekTable,
    /// Encrypted parameters keyed by parameter name as returned by the server.
    /// Plaintext parameters are not present.
    pub parameters: HashMap<String, ParameterCryptoInfo>,
}
334
#[cfg(feature = "always-encrypted")]
impl ParameterEncryptionInfo {
    /// Creates an empty description: no CEKs and no encrypted parameters.
    pub fn new() -> Self {
        let cek_table = CekTable::new();
        let parameters = HashMap::new();
        Self {
            cek_table,
            parameters,
        }
    }

    /// Looks up the encryption metadata for parameter `name`; `None` means the
    /// parameter is sent in plaintext (or is unknown).
    pub fn get_parameter(&self, name: &str) -> Option<&ParameterCryptoInfo> {
        let params = &self.parameters;
        params.get(name)
    }
}
350
/// The default description is the empty one produced by
/// [`ParameterEncryptionInfo::new`].
#[cfg(feature = "always-encrypted")]
impl Default for ParameterEncryptionInfo {
    fn default() -> Self {
        ParameterEncryptionInfo::new()
    }
}
357
/// Encryption metadata for a single parameter, taken from one row of
/// `sp_describe_parameter_encryption`'s second result set.
#[cfg(feature = "always-encrypted")]
#[derive(Debug, Clone)]
pub(crate) struct ParameterCryptoInfo {
    /// Positional index into `ParameterEncryptionInfo::cek_table`. This is not
    /// the server's `column_encryption_key_ordinal`; the parser translates the
    /// server ordinal into this index.
    pub cek_ordinal: u16,
    /// Deterministic or randomized encryption, as reported by the server.
    pub encryption_type: EncryptionTypeWire,
    /// Raw `column_encryption_algorithm` value from the server.
    pub algorithm_id: u8,
    /// Raw `column_encryption_normalization_rule_version` value from the server.
    pub normalization_rule_version: u8,
}
376
#[cfg(feature = "always-encrypted")]
impl ParameterEncryptionInfo {
    /// Minimum column count of result set 1 (one row per encrypted CEK value).
    const RS1_MIN_COLS: usize = 9;
    /// Minimum column count of result set 2 (one row per statement parameter).
    const RS2_MIN_COLS: usize = 6;

    /// Builds a [`ParameterEncryptionInfo`] from the result sets returned by
    /// `sp_describe_parameter_encryption`.
    ///
    /// Result set 1 rows are grouped by `column_encryption_key_ordinal`. Every
    /// row with the same ordinal adds another encrypted value to one CEK entry,
    /// and entries keep first-seen order. Result set 2 rows become
    /// [`ParameterCryptoInfo`]s, with the server ordinal translated to the
    /// entry's position. Rows whose `column_encryption_type` is 0 (plaintext)
    /// are skipped.
    ///
    /// # Errors
    ///
    /// `Error::Protocol` if fewer than two result sets are present, a result set
    /// has too few columns, a column has an unexpected SQL type, an id/version
    /// column is negative, there are more than `u16::MAX + 1` CEKs, a
    /// parameter has an unknown encryption type, or it references a CEK
    /// ordinal missing from result set 1. Errors from collecting rows are
    /// propagated.
    pub(crate) fn from_describe_result_sets(result_sets: &mut [ResultSet]) -> Result<Self, Error> {
        if result_sets.len() < 2 {
            return Err(Error::Protocol(format!(
                "sp_describe_parameter_encryption returned {} result set(s), expected 2",
                result_sets.len()
            )));
        }

        let rs1_cols = result_sets[0].columns().len();
        if rs1_cols < Self::RS1_MIN_COLS {
            return Err(Error::Protocol(format!(
                "sp_describe_parameter_encryption result set 1 has {rs1_cols} columns, expected >= {}",
                Self::RS1_MIN_COLS
            )));
        }
        let rs1_rows = result_sets[0].collect_all()?;

        let mut entries: Vec<CekTableEntry> = Vec::new();
        // Server ordinal -> position in `entries`.
        let mut ordinal_to_index: HashMap<i32, u16> = HashMap::new();

        for row in &rs1_rows {
            let key_ordinal = describe_int(row, 0, "column_encryption_key_ordinal")?;
            let value = CekValue {
                encrypted_value: describe_varbinary(
                    row,
                    5,
                    "column_encryption_key_encrypted_value",
                )?,
                key_store_provider_name: describe_nvarchar(
                    row,
                    6,
                    "column_master_key_store_provider_name",
                )?,
                cmk_path: describe_nvarchar(row, 7, "column_master_key_path")?,
                encryption_algorithm: describe_nvarchar(
                    row,
                    8,
                    "column_encryption_key_encryption_algorithm_name",
                )?,
            };

            if let Some(&idx) = ordinal_to_index.get(&key_ordinal) {
                // Same CEK, another encrypted value (different CMK / key store).
                entries[idx as usize].values.push(value);
            } else {
                let idx = u16::try_from(entries.len()).map_err(|_| {
                    Error::Protocol(
                        "sp_describe_parameter_encryption returned too many CEKs".into(),
                    )
                })?;
                ordinal_to_index.insert(key_ordinal, idx);
                entries.push(CekTableEntry {
                    // Checked conversions: a negative id must not silently wrap
                    // into a huge u32 (it would become a bogus cache key).
                    database_id: Self::describe_u32(row, 1, "database_id")?,
                    cek_id: Self::describe_u32(row, 2, "column_encryption_key_id")?,
                    cek_version: Self::describe_u32(row, 3, "column_encryption_key_version")?,
                    cek_md_version: describe_md_version(row, 4)?,
                    values: vec![value],
                });
            }
        }
        let cek_table = CekTable { entries };

        let rs2_cols = result_sets[1].columns().len();
        if rs2_cols < Self::RS2_MIN_COLS {
            return Err(Error::Protocol(format!(
                "sp_describe_parameter_encryption result set 2 has {rs2_cols} columns, expected >= {}",
                Self::RS2_MIN_COLS
            )));
        }
        let rs2_rows = result_sets[1].collect_all()?;

        let mut parameters = HashMap::new();
        for row in &rs2_rows {
            let name = describe_nvarchar(row, 1, "parameter_name")?;
            let encryption_type_byte = describe_tinyint(row, 3, "column_encryption_type")?;
            if encryption_type_byte == 0 {
                // Plaintext parameter: nothing to encrypt, remaining columns not read.
                continue;
            }
            let encryption_type =
                EncryptionTypeWire::from_u8(encryption_type_byte).ok_or_else(|| {
                    Error::Protocol(format!(
                        "sp_describe_parameter_encryption: invalid column_encryption_type {encryption_type_byte} for {name}"
                    ))
                })?;
            let algorithm_id = describe_tinyint(row, 2, "column_encryption_algorithm")?;
            let server_ordinal = describe_int(row, 4, "column_encryption_key_ordinal")?;
            let normalization_rule_version =
                describe_tinyint(row, 5, "column_encryption_normalization_rule_version")?;

            let cek_ordinal = *ordinal_to_index.get(&server_ordinal).ok_or_else(|| {
                Error::Protocol(format!(
                    "sp_describe_parameter_encryption: parameter {name} references CEK ordinal {server_ordinal} absent from the CEK table"
                ))
            })?;

            parameters.insert(
                name,
                ParameterCryptoInfo {
                    cek_ordinal,
                    encryption_type,
                    algorithm_id,
                    normalization_rule_version,
                },
            );
        }

        Ok(Self {
            cek_table,
            parameters,
        })
    }

    /// Reads an `int` describe column holding an id or version and converts it
    /// to `u32`, rejecting negative values with a protocol error.
    fn describe_u32(row: &Row, idx: usize, col: &str) -> Result<u32, Error> {
        let value = describe_int(row, idx, col)?;
        u32::try_from(value).map_err(|_| {
            Error::Protocol(format!(
                "sp_describe_parameter_encryption column {col} (#{idx}): expected non-negative int, got {value}"
            ))
        })
    }
}
514
/// Reads column `idx` of a describe row as an `int`; any other SQL type, or a
/// missing column, becomes a protocol error naming `col`.
#[cfg(feature = "always-encrypted")]
fn describe_int(row: &Row, idx: usize, col: &str) -> Result<i32, Error> {
    let raw = row.get_raw(idx);
    if let Some(SqlValue::Int(number)) = raw {
        return Ok(number);
    }
    Err(describe_type_error(col, idx, "int", raw.as_ref()))
}
523
/// Reads column `idx` of a describe row as a `tinyint`; any other SQL type, or
/// a missing column, becomes a protocol error naming `col`.
#[cfg(feature = "always-encrypted")]
fn describe_tinyint(row: &Row, idx: usize, col: &str) -> Result<u8, Error> {
    let raw = row.get_raw(idx);
    if let Some(SqlValue::TinyInt(byte)) = raw {
        return Ok(byte);
    }
    Err(describe_type_error(col, idx, "tinyint", raw.as_ref()))
}
532
/// Reads column `idx` of a describe row as text (`nvarchar`/`sysname`); any
/// other SQL type, or a missing column, becomes a protocol error naming `col`.
#[cfg(feature = "always-encrypted")]
fn describe_nvarchar(row: &Row, idx: usize, col: &str) -> Result<String, Error> {
    let raw = row.get_raw(idx);
    match raw {
        Some(SqlValue::String(text)) => Ok(text),
        ref unexpected => Err(describe_type_error(col, idx, "nvarchar", unexpected.as_ref())),
    }
}
541
/// Reads column `idx` of a describe row as binary data; any other SQL type, or
/// a missing column, becomes a protocol error naming `col`.
#[cfg(feature = "always-encrypted")]
fn describe_varbinary(row: &Row, idx: usize, col: &str) -> Result<bytes::Bytes, Error> {
    let raw = row.get_raw(idx);
    match raw {
        Some(SqlValue::Binary(blob)) => Ok(blob),
        ref unexpected => Err(describe_type_error(col, idx, "varbinary", unexpected.as_ref())),
    }
}
550
/// Reads the 8-byte `column_encryption_key_metadata_version` column and
/// interprets it as a little-endian `u64`.
///
/// Anything other than exactly 8 bytes of binary data is a protocol error.
#[cfg(feature = "always-encrypted")]
fn describe_md_version(row: &Row, idx: usize) -> Result<u64, Error> {
    let raw = row.get_raw(idx);
    if let Some(SqlValue::Binary(ref blob)) = raw {
        // `try_from` succeeds exactly when the slice is 8 bytes long.
        if let Ok(le_bytes) = <[u8; 8]>::try_from(&blob[..]) {
            return Ok(u64::from_le_bytes(le_bytes));
        }
    }
    Err(describe_type_error(
        "column_encryption_key_metadata_version",
        idx,
        "binary(8)",
        raw.as_ref(),
    ))
}
568
/// Builds the protocol error for a describe column of the wrong type; `got` is
/// `None` when the column was absent from the row.
#[cfg(feature = "always-encrypted")]
fn describe_type_error(col: &str, idx: usize, expected: &str, got: Option<&SqlValue>) -> Error {
    let got = match got {
        Some(value) => value.type_name(),
        None => "missing",
    };
    let message = format!(
        "sp_describe_parameter_encryption column {col} (#{idx}): expected {expected}, got {got}"
    );
    Error::Protocol(message)
}
577
/// Serializes a plaintext parameter value into the byte layout that is fed to
/// the Always Encrypted cipher.
///
/// `param_type`, when given, selects layouts that cannot be inferred from the
/// [`SqlValue`] alone:
///
/// - `Char` + `String`: Windows-1252 bytes;
/// - with the `chrono` feature, scale-aware `Time` / `DateTime2` /
///   `DateTimeOffset` layouts and the legacy `DateTime` layout.
///
/// Otherwise the layout follows the value's variant. `bit` and the integer
/// types widen to an 8-byte little-endian `i64`. Floats are little-endian
/// IEEE bytes, strings are UTF-16LE, and binary is copied verbatim. GUIDs use
/// SQL Server's mixed-endian byte order. Dates are a 3-byte day count from
/// 0001-01-01. Decimals are a sign byte followed by the 16-byte magnitude.
/// Money and smallmoney both use the 8-byte MONEY layout.
///
/// # Errors
///
/// - `EncryptionFailed` if `char` text has characters outside Windows-1252.
/// - `UnsupportedOperation` for an out-of-range time scale or MONEY value.
/// - `UnsupportedOperation` for an unencodable SMALLDATETIME.
/// - `UnsupportedOperation` for variants with no normalization (e.g. `Null`).
#[cfg(feature = "always-encrypted")]
pub fn normalize_for_encryption(
    value: &SqlValue,
    param_type: Option<mssql_types::EncryptedParamType>,
) -> Result<Vec<u8>, EncryptionError> {
    // char(n) is single-byte text: encode it in Windows-1252 rather than UTF-16.
    if let (Some(mssql_types::EncryptedParamType::Char { .. }), SqlValue::String(s)) =
        (param_type, value)
    {
        let (encoded, _, had_errors) = encoding_rs::WINDOWS_1252.encode(s);
        if had_errors {
            // Some character had no Windows-1252 mapping; refuse rather than
            // encrypt altered text.
            return Err(EncryptionError::EncryptionFailed(
                "char value contains characters not representable in Windows-1252 \
                 (the char column code page); use nchar for non-Latin text"
                    .to_string(),
            ));
        }
        return Ok(encoded.into_owned());
    }
    // Temporal layouts that depend on the declared column type/scale.
    #[cfg(feature = "chrono")]
    {
        use mssql_types::EncryptedParamType as E;
        match (param_type, value) {
            (Some(E::Time { scale }), SqlValue::Time(t)) => return normalize_ae_time(*t, scale),
            (Some(E::DateTime2 { scale }), SqlValue::DateTime(dt)) => {
                return normalize_ae_datetime2(*dt, scale);
            }
            (Some(E::DateTimeOffset { scale }), SqlValue::DateTimeOffset(dto)) => {
                return normalize_ae_datetimeoffset(*dto, scale);
            }
            (Some(E::DateTime), SqlValue::DateTime(dt)) => {
                let mut buf = bytes::BytesMut::with_capacity(8);
                mssql_types::__private::encode_datetime_legacy(*dt, &mut buf);
                return Ok(buf.to_vec());
            }
            _ => {}
        }
    }
    match value {
        SqlValue::Bool(v) => Ok(i64::from(*v).to_le_bytes().to_vec()),
        SqlValue::TinyInt(v) => Ok(i64::from(*v).to_le_bytes().to_vec()),
        SqlValue::SmallInt(v) => Ok(i64::from(*v).to_le_bytes().to_vec()),
        SqlValue::Int(v) => Ok(i64::from(*v).to_le_bytes().to_vec()),
        SqlValue::BigInt(v) => Ok(v.to_le_bytes().to_vec()),
        SqlValue::Float(v) => Ok(v.to_le_bytes().to_vec()),
        SqlValue::Double(v) => Ok(v.to_le_bytes().to_vec()),
        SqlValue::String(s) => Ok(s.encode_utf16().flat_map(u16::to_le_bytes).collect()),
        SqlValue::Binary(b) => Ok(b.to_vec()),
        #[cfg(feature = "uuid")]
        SqlValue::Uuid(u) => {
            // First three groups little-endian, last eight bytes as-is.
            let b = u.as_bytes();
            Ok(vec![
                b[3], b[2], b[1], b[0], b[5], b[4], b[7], b[6], b[8], b[9], b[10], b[11], b[12],
                b[13], b[14], b[15],
            ])
        }
        // Same 3-byte day count used by the datetime2/datetimeoffset layouts.
        #[cfg(feature = "chrono")]
        SqlValue::Date(d) => Ok(ae_date_bytes(*d).to_vec()),
        #[cfg(feature = "decimal")]
        SqlValue::Decimal(d) => {
            // NOTE(review): the mantissa is written at the value's own scale; if
            // the target column's scale differs, the server will read a different
            // number. Confirm callers rescale to the column's scale first.
            let mut out = Vec::with_capacity(17);
            out.push(u8::from(!d.is_sign_negative()));
            out.extend_from_slice(&d.mantissa().unsigned_abs().to_le_bytes());
            Ok(out)
        }
        #[cfg(feature = "decimal")]
        SqlValue::Money(d) | SqlValue::SmallMoney(d) => {
            // MONEY layout: high 32 bits (signed) first, then low 32 bits.
            let cents = money_cents(d)?;
            let mut out = ((cents >> 32) as i32).to_le_bytes().to_vec();
            out.extend_from_slice(&(cents as u32).to_le_bytes());
            Ok(out)
        }
        #[cfg(feature = "chrono")]
        SqlValue::SmallDateTime(dt) => {
            let mut buf = bytes::BytesMut::with_capacity(4);
            mssql_types::__private::encode_smalldatetime(*dt, &mut buf).map_err(|e| {
                EncryptionError::UnsupportedOperation(format!("SMALLDATETIME: {e}"))
            })?;
            Ok(buf.to_vec())
        }
        other => Err(EncryptionError::UnsupportedOperation(format!(
            "Always Encrypted parameter encryption is not yet implemented for {}",
            other.type_name()
        ))),
    }
}
706
/// Encodes a date as the 3-byte little-endian count of days since 0001-01-01
/// (chrono's CE day 1 maps to 0).
#[cfg(all(feature = "always-encrypted", feature = "chrono"))]
fn ae_date_bytes(d: chrono::NaiveDate) -> [u8; 3] {
    use chrono::Datelike;
    let day_index = (d.num_days_from_ce() - 1) as u32;
    let [lo, mid, hi, _] = day_index.to_le_bytes();
    [lo, mid, hi]
}
716
/// Encodes a time of day as a 5-byte little-endian count of 100 ns ticks since
/// midnight.
///
/// The width is always the scale-7 width. For lower scales, the tick count is
/// truncated down to a multiple of `10^(7 - scale)`.
///
/// # Errors
///
/// `UnsupportedOperation` if `scale` exceeds 7.
#[cfg(all(feature = "always-encrypted", feature = "chrono"))]
fn normalize_ae_time(t: chrono::NaiveTime, scale: u8) -> Result<Vec<u8>, EncryptionError> {
    use chrono::Timelike;
    if scale > 7 {
        return Err(EncryptionError::UnsupportedOperation(format!(
            "time scale {scale} out of range (0–7)"
        )));
    }
    let whole_seconds = u64::from(t.num_seconds_from_midnight());
    let total_nanos = whole_seconds * 1_000_000_000 + u64::from(t.nanosecond());
    let ticks = total_nanos / 100;
    let step = 10u64.pow(7 - u32::from(scale));
    let truncated = ticks - ticks % step;
    let le = truncated.to_le_bytes();
    Ok(le[..5].to_vec())
}
744
/// Encodes a `datetime2`: the 5-byte time layout followed by the 3-byte date.
///
/// # Errors
///
/// Propagates the scale error from the time encoding.
#[cfg(all(feature = "always-encrypted", feature = "chrono"))]
fn normalize_ae_datetime2(
    dt: chrono::NaiveDateTime,
    scale: u8,
) -> Result<Vec<u8>, EncryptionError> {
    let time_part = normalize_ae_time(dt.time(), scale)?;
    let date_part = ae_date_bytes(dt.date());
    Ok(time_part
        .into_iter()
        .chain(date_part.iter().copied())
        .collect())
}
756
/// Encodes a `datetimeoffset` in three parts:
///
/// - the UTC instant's 5-byte time layout;
/// - its 3-byte date;
/// - the offset from UTC in whole minutes, as a little-endian `i16`.
///
/// # Errors
///
/// Propagates the scale error from the time encoding.
#[cfg(all(feature = "always-encrypted", feature = "chrono"))]
fn normalize_ae_datetimeoffset(
    dto: chrono::DateTime<chrono::FixedOffset>,
    scale: u8,
) -> Result<Vec<u8>, EncryptionError> {
    use chrono::Offset;
    let offset_minutes = (dto.offset().fix().local_minus_utc() / 60) as i16;
    let as_utc = dto.naive_utc();
    let mut encoded = normalize_ae_time(as_utc.time(), scale)?;
    encoded.extend_from_slice(&ae_date_bytes(as_utc.date()));
    encoded.extend_from_slice(&offset_minutes.to_le_bytes());
    Ok(encoded)
}
772
/// Converts a decimal to MONEY's integer representation: the value in units of
/// 1/10,000 (four implied decimal places).
///
/// Values with more than four fractional digits are rounded half away from
/// zero to four places (e.g. 1.23456 -> 12346, -1.23455 -> -12346). This
/// matches SQL Server's decimal-to-money conversion and `SqlMoney`. The
/// previous truncation produced a different plaintext than those.
///
/// # Errors
///
/// `UnsupportedOperation("MONEY value out of range")` if the scaled value does
/// not fit in an `i64`.
#[cfg(all(feature = "always-encrypted", feature = "decimal"))]
fn money_cents(value: &rust_decimal::Decimal) -> Result<i64, EncryptionError> {
    let out_of_range =
        || EncryptionError::UnsupportedOperation("MONEY value out of range".into());
    let mantissa = value.mantissa();
    let scale = value.scale();
    let cents: i128 = if scale <= 4 {
        // Widen to exactly four fractional digits.
        mantissa
            .checked_mul(10_i128.pow(4 - scale))
            .ok_or_else(out_of_range)?
    } else {
        // Drop the extra digits, rounding half away from zero. scale <= 28, so
        // the divisor is at most 10^24 and 2 * |remainder| cannot overflow.
        let divisor = 10_i128.pow(scale - 4);
        let quotient = mantissa / divisor;
        let remainder = mantissa % divisor;
        if remainder.abs() * 2 >= divisor {
            quotient + mantissa.signum()
        } else {
            quotient
        }
    };
    i64::try_from(cents).map_err(|_| out_of_range())
}
791
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    // Known-answer test: the default normalization (no parameter type) of int,
    // string and binary values, deterministically encrypted under a fixed CEK,
    // must equal the reference ciphertexts (hex) that the assertion message
    // attributes to Microsoft.Data.SqlClient.
    #[cfg(feature = "always-encrypted")]
    #[test]
    fn ae_normalization_matches_dotnet() {
        use bytes::Bytes;

        // Decodes an even-length hex string (test-only; panics on bad input).
        fn unhex(s: &str) -> Vec<u8> {
            (0..s.len())
                .step_by(2)
                .map(|i| u8::from_str_radix(&s[i..i + 2], 16).unwrap())
                .collect()
        }

        let cek = unhex("B59D9F2C96784C232D53AB273D257DC79B7D2355BB82B1EC7054CE25E25F7B44");
        let enc = AeadEncryptor::new(&cek).unwrap();

        for (value, reference) in [
            (
                SqlValue::Int(42),
                "01102FC5DEC5D3E463A8F4BDF512AA74E6AB953BA9A2F3F9A98CD18446B007DE5A6E2A1D1EB775035EA189CA5160A935CE093CAA9BB7E9233BB333AADEE86FDE1D",
            ),
            (
                SqlValue::String("Ada".to_string()),
                "01BFAC40E6DA541ACEFAD8ECF5598DB77B0C5349CFACBC3C9221C01B6037E593B78E8F398F620F837BD6A4A2B644125C4188DF278B94479B2218466D91107FE417",
            ),
            (
                SqlValue::Binary(Bytes::from_static(&[0x01, 0x02, 0x03])),
                "01ADE71457495F00FC9A16456F1B1EECB901D88DE97887025C189B1C4432E02071AB7594C48518CA5621E90165FAE337475B4CF3A3D00EF2D862FB0473713DF1E1",
            ),
        ] {
            let norm = normalize_for_encryption(&value, None).unwrap();
            // Deterministic mode makes the ciphertext reproducible for comparison.
            let cipher = enc
                .encrypt(&norm, mssql_auth::EncryptionType::Deterministic)
                .unwrap();
            assert_eq!(
                cipher,
                unhex(reference),
                "ciphertext for {} must match Microsoft.Data.SqlClient",
                value.type_name()
            );
        }
    }

    // NULL has no normalization and must be rejected rather than encrypted.
    #[cfg(feature = "always-encrypted")]
    #[test]
    fn ae_normalization_rejects_unnormalizable_value() {
        assert!(normalize_for_encryption(&SqlValue::Null, None).is_err());
    }

    // Known-answer test for the numeric layouts (bigint, smallint, tinyint,
    // bit, real, float) under a second fixed CEK.
    #[cfg(feature = "always-encrypted")]
    #[test]
    fn ae_normalization_matches_dotnet_numeric() {
        fn unhex(s: &str) -> Vec<u8> {
            (0..s.len())
                .step_by(2)
                .map(|i| u8::from_str_radix(&s[i..i + 2], 16).unwrap())
                .collect()
        }

        let cek = unhex("9590E42A8A6C8F13B5D09B8D5A128EF8B3A4A10301C7AF24AFC62ED0E02342F7");
        let enc = AeadEncryptor::new(&cek).unwrap();

        for (value, reference) in [
            (
                SqlValue::BigInt(0x0102030405060708),
                "01E765FC4696660028BFD48FCAEAED81E0EB423CFF433CA97F1B2FF02F70744E7265C2AE73CAA562FFA98AF98CB1D3EF6A4649B3640359E1DB7D170C80E639DA68",
            ),
            (
                SqlValue::SmallInt(258),
                "012545AB817E1AEBDCEE1C00AEBFF3A013CAD20E0377BEFDD9186C263F8D1A909C313A753996F1B5E4A4AE17E901F6F781DCA707544766995D339601CA414063A0",
            ),
            (
                SqlValue::TinyInt(200),
                "01A97C33480277D16FFAEDA9068173D4173378542F2887EBCD31CDEEEB116BD59D48F9D459BDDCABAE469E891B4F82AA3D283440CA1B5E9FFC150F9D0AE54EC21E",
            ),
            (
                SqlValue::Bool(true),
                "01DDE18564051D630EE026331BCCAFC8F4122CC3919F81459F37D9C0E0C64A5317FCA08660FE5FC855917B97B72013F25B85ADD14ADDD7D5ED022EB1297FF29A7E",
            ),
            (
                SqlValue::Float(3.5),
                "017A452760E7BA7AA6A716F6707F55D9C3A81683C04A6B561B13AC1D8A848E93E239BB922EE3EE628B6D0081A590BB11747CC25D216240FB10171A0FA3B99A2DB3",
            ),
            (
                SqlValue::Double(3.5),
                "0171611557351FBC4561EBF0B9C98E0DC38AD2BD3E2C1D1E82F185D7E67D0425E506D11DD67BA3EB38F34FB01A8FCEF7E4B9A7256944334A521526613CFF6C8C5F",
            ),
        ] {
            let norm = normalize_for_encryption(&value, None).unwrap();
            let cipher = enc
                .encrypt(&norm, mssql_auth::EncryptionType::Deterministic)
                .unwrap();
            assert_eq!(
                cipher,
                unhex(reference),
                "ciphertext for {} must match Microsoft.Data.SqlClient",
                value.type_name()
            );
        }
    }

    // Known-answer test for the GUID (mixed-endian) and date (3-byte day count)
    // layouts.
    #[cfg(all(feature = "always-encrypted", feature = "uuid", feature = "chrono"))]
    #[test]
    fn ae_normalization_matches_dotnet_uuid_date() {
        fn unhex(s: &str) -> Vec<u8> {
            (0..s.len())
                .step_by(2)
                .map(|i| u8::from_str_radix(&s[i..i + 2], 16).unwrap())
                .collect()
        }

        let cek = unhex("9590E42A8A6C8F13B5D09B8D5A128EF8B3A4A10301C7AF24AFC62ED0E02342F7");
        let enc = AeadEncryptor::new(&cek).unwrap();

        for (value, reference) in [
            (
                SqlValue::Uuid(
                    uuid::Uuid::parse_str("01020304-0506-0708-090a-0b0c0d0e0f10").unwrap(),
                ),
                "01F58635AA18692D68BDF551ECDD7AC3A56682D3F91F111F8D8F36D5425C405A8F6AB3ED3C3666444478476BD65FF40DC83F6831F502826AFEEC3116F71A7A2020CCD254F4BA28FCDC0F96BA2E5264AE9E",
            ),
            (
                SqlValue::Date(chrono::NaiveDate::from_ymd_opt(2024, 3, 15).unwrap()),
                "0188B4F75A1F4BDA53C9CDDC1918C09CB57F68E13F5560F1F1D7168FE70707337B1156A97915B244F3C03D3E7352882A599511BD243471FD03683F371CF44E4B76",
            ),
        ] {
            let norm = normalize_for_encryption(&value, None).unwrap();
            let cipher = enc
                .encrypt(&norm, mssql_auth::EncryptionType::Deterministic)
                .unwrap();
            assert_eq!(
                cipher,
                unhex(reference),
                "ciphertext for {} must match Microsoft.Data.SqlClient",
                value.type_name()
            );
        }
    }

    // Known-answer test for the decimal and MONEY layouts. MONEY and SMALLMONEY
    // share the 8-byte layout, so both must produce the same ciphertext.
    #[cfg(all(feature = "always-encrypted", feature = "decimal"))]
    #[test]
    fn ae_normalization_matches_dotnet_decimal_money() {
        fn unhex(s: &str) -> Vec<u8> {
            (0..s.len())
                .step_by(2)
                .map(|i| u8::from_str_radix(&s[i..i + 2], 16).unwrap())
                .collect()
        }

        let cek = unhex("CBFB5AE21FB517C65DA0C6E8E11969C630798E473EF5827A70398012DF1D4B9E");
        let enc = AeadEncryptor::new(&cek).unwrap();
        // 12345.6789
        let dec = rust_decimal::Decimal::new(123_456_789, 4);
        // 12.3400 (already at MONEY's four-digit scale)
        let money = rust_decimal::Decimal::new(123_400, 4);
        for (value, reference) in [
            (
                SqlValue::Decimal(dec),
                "018FAE46024B9B406C23600E6A9C694F9A9B39B785A995689EBE19437BA7E75768011A035A5B54B5E495512EBB46AE1146130940A0D0D834D61AA89B5AD9F71FFAF6EEEAE77E4856BA2AA5E016E2950A8D",
            ),
            (
                SqlValue::Money(money),
                "01B4CE4CAD8D6B241A1555C377A0ADD4C79424DD5162F710D116594F725C1BAB015169A0C7716076EEC90E013519B961DEF427BFC32462D9E45D166C791B73F793",
            ),
            (
                SqlValue::SmallMoney(money),
                "01B4CE4CAD8D6B241A1555C377A0ADD4C79424DD5162F710D116594F725C1BAB015169A0C7716076EEC90E013519B961DEF427BFC32462D9E45D166C791B73F793",
            ),
        ] {
            let norm = normalize_for_encryption(&value, None).unwrap();
            let cipher = enc
                .encrypt(&norm, mssql_auth::EncryptionType::Deterministic)
                .unwrap();
            assert_eq!(
                cipher,
                unhex(reference),
                "ciphertext for {} must match Microsoft.Data.SqlClient",
                value.type_name()
            );
        }
    }

    // Plaintext-layout checks for the temporal types at 2024-03-15 13:14:15.1234567.
    #[cfg(all(feature = "always-encrypted", feature = "chrono"))]
    #[test]
    fn ae_normalization_matches_dotnet_temporal() {
        use mssql_types::EncryptedParamType as E;
        fn unhex(s: &str) -> Vec<u8> {
            (0..s.len())
                .step_by(2)
                .map(|i| u8::from_str_radix(&s[i..i + 2], 16).unwrap())
                .collect()
        }

        let day = chrono::NaiveDate::from_ymd_opt(2024, 3, 15).unwrap();
        let dt = day.and_hms_nano_opt(13, 14, 15, 123_456_700).unwrap();

        // time(7): 5-byte little-endian tick count.
        assert_eq!(
            normalize_for_encryption(&SqlValue::Time(dt.time()), Some(E::Time { scale: 7 }))
                .unwrap(),
            unhex("07c4aaf46e"),
        );
        // datetime2(7): time bytes followed by the 3-byte date.
        assert_eq!(
            normalize_for_encryption(&SqlValue::DateTime(dt), Some(E::DateTime2 { scale: 7 }))
                .unwrap(),
            unhex("07c4aaf46e8f460b"),
        );
        // Same local wall-clock time at +05:30.
        let dto = {
            use chrono::TimeZone;
            chrono::FixedOffset::east_opt(5 * 3600 + 30 * 60)
                .unwrap()
                .from_local_datetime(&dt)
                .single()
                .unwrap()
        };
        // datetimeoffset(7): UTC time + UTC date + offset minutes (330 = 0x014a).
        assert_eq!(
            normalize_for_encryption(
                &SqlValue::DateTimeOffset(dto),
                Some(E::DateTimeOffset { scale: 7 })
            )
            .unwrap(),
            unhex("0788f2da408f460b4a01"),
        );

        // Scale 3: same 5-byte width, sub-millisecond ticks truncated away.
        assert_eq!(
            normalize_for_encryption(&SqlValue::Time(dt.time()), Some(E::Time { scale: 3 }))
                .unwrap(),
            unhex("30b2aaf46e"),
        );
        assert_eq!(
            normalize_for_encryption(&SqlValue::DateTime(dt), Some(E::DateTime2 { scale: 3 }))
                .unwrap(),
            unhex("30b2aaf46e8f460b"),
        );
        assert_eq!(
            normalize_for_encryption(
                &SqlValue::DateTimeOffset(dto),
                Some(E::DateTimeOffset { scale: 3 })
            )
            .unwrap(),
            unhex("3076f2da408f460b4a01"),
        );

        // Legacy datetime layout (selected via the DateTime parameter type).
        let dt_legacy = day.and_hms_milli_opt(13, 14, 15, 123).unwrap();
        assert_eq!(
            normalize_for_encryption(&SqlValue::DateTime(dt_legacy), Some(E::DateTime)).unwrap(),
            unhex("34b10000d925da00"),
        );
        // smalldatetime is chosen from the value variant alone (no parameter type).
        let sdt = day.and_hms_opt(13, 14, 0).unwrap();
        assert_eq!(
            normalize_for_encryption(&SqlValue::SmallDateTime(sdt), None).unwrap(),
            unhex("34b11a03"),
        );
    }

    // Fixed-width char/nchar/binary values are normalized without padding to the
    // declared length (10), and char uses single-byte encoding.
    #[cfg(feature = "always-encrypted")]
    #[test]
    fn ae_normalization_matches_dotnet_fixed_width() {
        use mssql_types::EncryptedParamType as E;
        fn unhex(s: &str) -> Vec<u8> {
            (0..s.len())
                .step_by(2)
                .map(|i| u8::from_str_radix(&s[i..i + 2], 16).unwrap())
                .collect()
        }
        // char: one byte per character.
        assert_eq!(
            normalize_for_encryption(
                &SqlValue::String("Hello".to_string()),
                Some(E::Char { length: 10 })
            )
            .unwrap(),
            unhex("48656c6c6f"),
        );
        // nchar: UTF-16LE.
        assert_eq!(
            normalize_for_encryption(
                &SqlValue::String("Hello".to_string()),
                Some(E::NChar { length: 10 })
            )
            .unwrap(),
            unhex("480065006c006c006f00"),
        );
        // binary: bytes verbatim.
        assert_eq!(
            normalize_for_encryption(
                &SqlValue::Binary(bytes::Bytes::from_static(&[1, 2, 3, 4, 5])),
                Some(E::Binary { length: 10 })
            )
            .unwrap(),
            unhex("0102030405"),
        );
    }

    // A char value with a character outside Windows-1252 must error instead of
    // being silently substituted; a Latin-1 character must map to its single byte.
    #[cfg(feature = "always-encrypted")]
    #[test]
    fn ae_char_rejects_non_windows_1252() {
        use mssql_types::EncryptedParamType as E;
        let r = normalize_for_encryption(
            &SqlValue::String("中".to_string()),
            Some(E::Char { length: 10 }),
        );
        assert!(
            r.is_err(),
            "non-Windows-1252 char must error, got {:?}",
            r.map(|b| b.iter().map(|x| format!("{x:02x}")).collect::<String>())
        );
        // 'é' is 0xE9 in Windows-1252.
        assert_eq!(
            normalize_for_encryption(
                &SqlValue::String("é".to_string()),
                Some(E::Char { length: 10 })
            )
            .unwrap(),
            vec![0xE9],
        );
    }

    // `new()` enables encryption and caching but is not ready until a provider
    // is registered.
    #[test]
    fn test_encryption_config_defaults() {
        let config = EncryptionConfig::new();
        assert!(config.enabled);
        assert!(config.cache_ceks);
        // No providers registered yet.
        assert!(!config.is_ready());
    }

    // Parsing checks for sp_describe_parameter_encryption output:
    // - two result-set-1 rows with the same server ordinal become one CEK entry
    //   with two values;
    // - server ordinals map to positional indices;
    // - plaintext parameters (encryption type 0) are skipped.
    #[cfg(feature = "always-encrypted")]
    #[test]
    fn parse_describe_result_sets_groups_ceks_and_skips_plaintext() {
        use crate::row::{Column, Row};
        use crate::stream::ResultSet;
        use bytes::Bytes;

        // Builds an in-memory result set with `n_cols` placeholder columns.
        fn rs(n_cols: usize, rows: Vec<Vec<SqlValue>>) -> ResultSet {
            let cols: Vec<Column> = (0..n_cols)
                .map(|i| Column::new(format!("c{i}"), i, "x"))
                .collect();
            let rows = rows
                .into_iter()
                .map(|vals| Row::from_values(cols.clone(), vals))
                .collect();
            ResultSet::new(cols, rows)
        }

        // 8-byte little-endian metadata versions 1 and 255.
        let mdv1 = Bytes::from_static(&[1, 0, 0, 0, 0, 0, 0, 0]);
        let mdv2 = Bytes::from_static(&[255, 0, 0, 0, 0, 0, 0, 0]);
        // Result set 1: ordinal 1 wrapped by two providers, ordinal 2 by one.
        let rs1 = rs(
            9,
            vec![
                vec![
                    SqlValue::Int(1),
                    SqlValue::Int(7),
                    SqlValue::Int(56),
                    SqlValue::Int(1),
                    SqlValue::Binary(mdv1.clone()),
                    SqlValue::Binary(Bytes::from_static(b"env-a")),
                    SqlValue::String("IN_MEMORY_KEY_STORE".into()),
                    SqlValue::String("path-a".into()),
                    SqlValue::String("RSA_OAEP".into()),
                ],
                vec![
                    SqlValue::Int(1),
                    SqlValue::Int(7),
                    SqlValue::Int(56),
                    SqlValue::Int(1),
                    SqlValue::Binary(mdv1),
                    SqlValue::Binary(Bytes::from_static(b"env-a2")),
                    SqlValue::String("PROV_2".into()),
                    SqlValue::String("path-a2".into()),
                    SqlValue::String("RSA_OAEP".into()),
                ],
                vec![
                    SqlValue::Int(2),
                    SqlValue::Int(7),
                    SqlValue::Int(57),
                    SqlValue::Int(1),
                    SqlValue::Binary(mdv2),
                    SqlValue::Binary(Bytes::from_static(b"env-b")),
                    SqlValue::String("IN_MEMORY_KEY_STORE".into()),
                    SqlValue::String("path-b".into()),
                    SqlValue::String("RSA_OAEP".into()),
                ],
            ],
        );

        // Result set 2: deterministic, randomized and plaintext parameters.
        let rs2 = rs(
            6,
            vec![
                vec![
                    SqlValue::Int(1),
                    SqlValue::String("@det".into()),
                    SqlValue::TinyInt(2),
                    SqlValue::TinyInt(1),
                    SqlValue::Int(1),
                    SqlValue::TinyInt(1),
                ],
                vec![
                    SqlValue::Int(2),
                    SqlValue::String("@rand".into()),
                    SqlValue::TinyInt(2),
                    SqlValue::TinyInt(2),
                    SqlValue::Int(2),
                    SqlValue::TinyInt(1),
                ],
                vec![
                    SqlValue::Int(3),
                    SqlValue::String("@plain".into()),
                    SqlValue::TinyInt(0),
                    SqlValue::TinyInt(0),
                    SqlValue::Int(0),
                    SqlValue::TinyInt(0),
                ],
            ],
        );

        let mut sets = vec![rs1, rs2];
        let info = ParameterEncryptionInfo::from_describe_result_sets(&mut sets).unwrap();

        assert_eq!(info.cek_table.len(), 2);
        let e0 = info.cek_table.get(0).unwrap();
        assert_eq!(e0.cek_id, 56);
        assert_eq!(e0.cek_md_version, 1);
        assert_eq!(e0.values.len(), 2, "two CMK-wrappings group under one CEK");
        assert_eq!(e0.values[0].key_store_provider_name, "IN_MEMORY_KEY_STORE");
        assert_eq!(e0.values[1].key_store_provider_name, "PROV_2");
        let e1 = info.cek_table.get(1).unwrap();
        assert_eq!(e1.cek_id, 57);
        assert_eq!(e1.cek_md_version, 255);

        let det = info.get_parameter("@det").unwrap();
        assert_eq!(det.encryption_type, EncryptionTypeWire::Deterministic);
        assert_eq!(det.algorithm_id, 2);
        assert_eq!(det.normalization_rule_version, 1);
        assert_eq!(det.cek_ordinal, 0, "server ordinal 1 -> positional index 0");

        let rand = info.get_parameter("@rand").unwrap();
        assert_eq!(rand.encryption_type, EncryptionTypeWire::Randomized);
        assert_eq!(
            rand.cek_ordinal, 1,
            "server ordinal 2 -> positional index 1"
        );

        assert!(info.get_parameter("@plain").is_none());
        assert_eq!(info.parameters.len(), 2);
    }

    // A single result set (the parameter set is missing) must be rejected.
    #[cfg(feature = "always-encrypted")]
    #[test]
    fn parse_describe_result_sets_rejects_missing_result_set() {
        use crate::row::{Column, Row};
        use crate::stream::ResultSet;

        let cols: Vec<Column> = (0..9)
            .map(|i| Column::new(format!("c{i}"), i, "x"))
            .collect();
        let mut sets = vec![ResultSet::new(cols, Vec::<Row>::new())];
        assert!(ParameterEncryptionInfo::from_describe_result_sets(&mut sets).is_err());
    }
}