1#[cfg(feature = "always-encrypted")]
103use std::collections::HashMap;
104
105use mssql_auth::KeyStoreProvider;
106#[cfg(feature = "always-encrypted")]
107use tds_protocol::crypto::{CekTable, CekTableEntry, EncryptionTypeWire};
108
109#[cfg(feature = "always-encrypted")]
110use mssql_auth::{AeadEncryptor, CekCache, CekCacheKey, EncryptionError};
111#[cfg(feature = "always-encrypted")]
112use mssql_types::SqlValue;
113#[cfg(feature = "always-encrypted")]
114use std::sync::Arc;
115
116#[cfg(feature = "always-encrypted")]
117use crate::{Error, row::Row, stream::ResultSet};
118#[cfg(feature = "always-encrypted")]
119use tds_protocol::crypto::CekValue;
120
121#[derive(Default)]
123pub struct EncryptionConfig {
124 pub enabled: bool,
126 providers: Vec<Box<dyn KeyStoreProvider>>,
128 pub cache_ceks: bool,
130}
131
132impl EncryptionConfig {
133 #[must_use]
135 pub fn new() -> Self {
136 Self {
137 enabled: true,
138 providers: Vec::new(),
139 cache_ceks: true,
140 }
141 }
142
143 pub fn register_provider(&mut self, provider: impl KeyStoreProvider + 'static) {
145 self.providers.push(Box::new(provider));
146 }
147
148 #[must_use]
150 pub fn with_provider(mut self, provider: impl KeyStoreProvider + 'static) -> Self {
151 self.register_provider(provider);
152 self
153 }
154
155 #[must_use]
157 pub fn with_cek_caching(mut self, enabled: bool) -> Self {
158 self.cache_ceks = enabled;
159 self
160 }
161
162 pub fn get_provider(&self, name: &str) -> Option<&dyn KeyStoreProvider> {
164 self.providers
165 .iter()
166 .find(|p| p.provider_name() == name)
167 .map(|p| p.as_ref())
168 }
169
170 #[must_use]
172 pub fn is_ready(&self) -> bool {
173 self.enabled && !self.providers.is_empty()
174 }
175}
176
177impl std::fmt::Debug for EncryptionConfig {
178 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
179 f.debug_struct("EncryptionConfig")
180 .field("enabled", &self.enabled)
181 .field("provider_count", &self.providers.len())
182 .field("cache_ceks", &self.cache_ceks)
183 .finish()
184 }
185}
186
/// Runtime Always Encrypted state: the shared [`EncryptionConfig`] plus a cache of
/// ready-to-use encryptors built from decrypted column encryption keys (CEKs).
#[cfg(feature = "always-encrypted")]
pub(crate) struct EncryptionContext {
    /// Shared configuration supplying the registered key store providers.
    config: std::sync::Arc<EncryptionConfig>,
    /// Encryptors for CEKs already decrypted, keyed by (database id, CEK id, CEK version).
    cek_cache: CekCache,
    /// Copy of `config.cache_ceks` taken when the context was created.
    cache_enabled: bool,
}

#[cfg(feature = "always-encrypted")]
impl EncryptionContext {
    /// Create a context over shared configuration with an empty CEK cache.
    ///
    /// The caching flag is read once here; later changes to the config are not observed.
    pub fn from_arc(config: std::sync::Arc<EncryptionConfig>) -> Self {
        let cache_enabled = config.cache_ceks;
        Self {
            config,
            cek_cache: CekCache::new(),
            cache_enabled,
        }
    }

    /// Cache key identifying the CEK described by `cek_entry`.
    fn cache_key(cek_entry: &CekTableEntry) -> CekCacheKey {
        CekCacheKey::new(
            cek_entry.database_id,
            cek_entry.cek_id,
            cek_entry.cek_version,
        )
    }

    /// Return an encryptor for the CEK described by `cek_entry`.
    ///
    /// If the CEK is not already cached, it is decrypted through a registered key store
    /// provider. A CEK may be wrapped by several column master keys, for example while a CMK
    /// is being rotated. Each wrapping in `cek_entry.values` is tried in order, and the first
    /// one a registered provider can unwrap is used. This mirrors Microsoft.Data.SqlClient.
    /// When caching is enabled, the resulting encryptor is stored in the CEK cache.
    ///
    /// # Errors
    ///
    /// - `EncryptionError::CekDecryptionFailed` if `cek_entry` carries no wrapped values.
    /// - Otherwise, if no wrapping could be unwrapped, the error from the last one attempted.
    ///   This is `KeyStoreNotFound` when its provider is not registered, or whatever the
    ///   provider or encryptor construction returned.
    pub(crate) async fn get_encryptor(
        &self,
        cek_entry: &CekTableEntry,
    ) -> Result<Arc<AeadEncryptor>, EncryptionError> {
        if self.cache_enabled {
            if let Some(encryptor) = self.cek_cache.get(&Self::cache_key(cek_entry)) {
                return Ok(encryptor);
            }
        }

        // A failure on one wrapping (e.g. its provider is not registered on this client)
        // must not prevent another wrapping of the same CEK from being used.
        let mut last_error = None;
        for cek_value in &cek_entry.values {
            match self.encryptor_from_wrapping(cek_entry, cek_value).await {
                Ok(encryptor) => return Ok(encryptor),
                Err(err) => last_error = Some(err),
            }
        }

        Err(last_error.unwrap_or_else(|| {
            EncryptionError::CekDecryptionFailed("No CEK value available".into())
        }))
    }

    /// Unwrap a single CMK-wrapped CEK value and build (and optionally cache) its encryptor.
    async fn encryptor_from_wrapping(
        &self,
        cek_entry: &CekTableEntry,
        cek_value: &CekValue,
    ) -> Result<Arc<AeadEncryptor>, EncryptionError> {
        let provider = self
            .config
            .get_provider(&cek_value.key_store_provider_name)
            .ok_or_else(|| {
                EncryptionError::KeyStoreNotFound(cek_value.key_store_provider_name.clone())
            })?;

        let decrypted_cek = provider
            .decrypt_cek(
                &cek_value.cmk_path,
                &cek_value.encryption_algorithm,
                &cek_value.encrypted_value,
            )
            .await?;

        if self.cache_enabled {
            self.cek_cache
                .insert(Self::cache_key(cek_entry), decrypted_cek)
        } else {
            Ok(Arc::new(AeadEncryptor::new(&decrypted_cek)?))
        }
    }

    /// Encrypt `plaintext` under the CEK in `cek_entry`.
    ///
    /// `plaintext` is presumably the output of `normalize_for_encryption`; confirm at call
    /// sites.
    ///
    /// # Errors
    ///
    /// Propagates failures from [`Self::get_encryptor`] and from the encryptor itself.
    /// Returns `EncryptionError::UnsupportedOperation` for encryption types other than
    /// deterministic or randomized.
    pub(crate) async fn encrypt_value(
        &self,
        plaintext: &[u8],
        cek_entry: &CekTableEntry,
        encryption_type: EncryptionTypeWire,
    ) -> Result<Vec<u8>, EncryptionError> {
        let encryptor = self.get_encryptor(cek_entry).await?;

        let enc_type = match encryption_type {
            EncryptionTypeWire::Deterministic => mssql_auth::EncryptionType::Deterministic,
            EncryptionTypeWire::Randomized => mssql_auth::EncryptionType::Randomized,
            _ => {
                return Err(EncryptionError::UnsupportedOperation(format!(
                    "unsupported encryption type: {encryption_type:?}"
                )));
            }
        };

        encryptor.encrypt(plaintext, enc_type)
    }

    /// `true` if a provider named `name` is registered in the configuration.
    pub fn has_provider(&self, name: &str) -> bool {
        self.config.get_provider(name).is_some()
    }
}

#[cfg(feature = "always-encrypted")]
impl std::fmt::Debug for EncryptionContext {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("EncryptionContext")
            .field("provider_count", &self.config.providers.len())
            .field("cache_entries", &self.cek_cache.len())
            .field("cache_enabled", &self.cache_enabled)
            .finish()
    }
}
321
/// Encryption metadata for one statement's parameters, as reported by
/// `sp_describe_parameter_encryption`.
#[cfg(feature = "always-encrypted")]
#[derive(Debug, Clone)]
pub(crate) struct ParameterEncryptionInfo {
    /// Column encryption keys referenced by the statement's encrypted parameters.
    pub cek_table: CekTable,
    /// Crypto metadata keyed by parameter name; parameters sent in plaintext have no entry.
    pub parameters: HashMap<String, ParameterCryptoInfo>,
}

#[cfg(feature = "always-encrypted")]
impl ParameterEncryptionInfo {
    /// Create metadata with an empty CEK table and no encrypted parameters.
    pub fn new() -> Self {
        let cek_table = CekTable::new();
        let parameters = HashMap::new();
        Self {
            cek_table,
            parameters,
        }
    }

    /// Look up the crypto metadata recorded for the parameter called `name`.
    pub fn get_parameter(&self, name: &str) -> Option<&ParameterCryptoInfo> {
        HashMap::get(&self.parameters, name)
    }
}

#[cfg(feature = "always-encrypted")]
impl Default for ParameterEncryptionInfo {
    fn default() -> Self {
        ParameterEncryptionInfo::new()
    }
}

/// Crypto metadata for a single encrypted parameter.
#[cfg(feature = "always-encrypted")]
#[derive(Debug, Clone)]
pub(crate) struct ParameterCryptoInfo {
    /// 0-based position of this parameter's CEK within the statement's CEK table
    /// (not the server-assigned key ordinal).
    pub cek_ordinal: u16,
    /// Encryption type reported for the parameter (e.g. deterministic or randomized).
    pub encryption_type: EncryptionTypeWire,
    /// `column_encryption_algorithm` value reported by the server.
    pub algorithm_id: u8,
    /// `column_encryption_normalization_rule_version` value reported by the server.
    pub normalization_rule_version: u8,
}
376
/// Parsing of `sp_describe_parameter_encryption` output.
///
/// Result set 1 lists column encryption keys, one row per CMK wrapping of each CEK.
/// Result set 2 lists the statement's parameters and which CEK, if any, protects each one.
#[cfg(feature = "always-encrypted")]
impl ParameterEncryptionInfo {
    /// Minimum column count of result set 1; columns 0 through 8 are read.
    const RS1_MIN_COLS: usize = 9;
    /// Minimum column count of result set 2; columns 1 through 5 are read.
    const RS2_MIN_COLS: usize = 6;

    /// Build parameter encryption metadata from the result sets returned by
    /// `sp_describe_parameter_encryption`.
    ///
    /// Rows of result set 1 that share a `column_encryption_key_ordinal` are grouped into one
    /// [`CekTableEntry`]. Its `values` hold every CMK wrapping in arrival order. The identity
    /// columns (1 through 4) are read from the first row of each group only. Entries are
    /// stored in first-seen order, and each parameter's server ordinal is translated to that
    /// positional index. Parameters whose `column_encryption_type` is 0 (plaintext) are omitted.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Protocol`] in any of these cases:
    /// - fewer than two result sets are present, or a result set has too few columns;
    /// - a column holds an unexpected type, or an encryption type byte is unknown;
    /// - a parameter references a CEK ordinal absent from result set 1;
    /// - there are too many distinct CEKs to index with a `u16`.
    ///
    /// Errors raised while reading the rows are propagated.
    pub(crate) fn from_describe_result_sets(result_sets: &mut [ResultSet]) -> Result<Self, Error> {
        let set_count = result_sets.len();
        if set_count < 2 {
            return Err(Error::Protocol(format!(
                "sp_describe_parameter_encryption returned {set_count} result set(s), expected 2"
            )));
        }

        // ---- Result set 1: column encryption keys -------------------------------------------
        let key_set_width = result_sets[0].columns().len();
        if key_set_width < Self::RS1_MIN_COLS {
            return Err(Error::Protocol(format!(
                "sp_describe_parameter_encryption result set 1 has {key_set_width} columns, expected >= {}",
                Self::RS1_MIN_COLS
            )));
        }
        let key_rows = result_sets[0].collect_all()?;

        let mut cek_entries: Vec<CekTableEntry> = Vec::new();
        // Server-assigned key ordinal -> position of its entry in `cek_entries`.
        let mut position_by_ordinal: HashMap<i32, u16> = HashMap::new();

        for key_row in &key_rows {
            let server_ordinal = describe_int(key_row, 0, "column_encryption_key_ordinal")?;
            let encrypted_value =
                describe_varbinary(key_row, 5, "column_encryption_key_encrypted_value")?;
            let key_store_provider_name =
                describe_nvarchar(key_row, 6, "column_master_key_store_provider_name")?;
            let cmk_path = describe_nvarchar(key_row, 7, "column_master_key_path")?;
            let encryption_algorithm = describe_nvarchar(
                key_row,
                8,
                "column_encryption_key_encryption_algorithm_name",
            )?;
            let wrapping = CekValue {
                encrypted_value,
                key_store_provider_name,
                cmk_path,
                encryption_algorithm,
            };

            match position_by_ordinal.get(&server_ordinal).copied() {
                // Another CMK wrapping of a CEK we have already seen.
                Some(position) => cek_entries[usize::from(position)].values.push(wrapping),
                None => {
                    let position = u16::try_from(cek_entries.len()).map_err(|_| {
                        Error::Protocol(
                            "sp_describe_parameter_encryption returned too many CEKs".into(),
                        )
                    })?;
                    position_by_ordinal.insert(server_ordinal, position);
                    let database_id = describe_int(key_row, 1, "database_id")? as u32;
                    let cek_id = describe_int(key_row, 2, "column_encryption_key_id")? as u32;
                    let cek_version =
                        describe_int(key_row, 3, "column_encryption_key_version")? as u32;
                    let cek_md_version = describe_md_version(key_row, 4)?;
                    cek_entries.push(CekTableEntry {
                        database_id,
                        cek_id,
                        cek_version,
                        cek_md_version,
                        values: vec![wrapping],
                    });
                }
            }
        }
        let cek_table = CekTable {
            entries: cek_entries,
        };

        // ---- Result set 2: parameters --------------------------------------------------------
        let param_set_width = result_sets[1].columns().len();
        if param_set_width < Self::RS2_MIN_COLS {
            return Err(Error::Protocol(format!(
                "sp_describe_parameter_encryption result set 2 has {param_set_width} columns, expected >= {}",
                Self::RS2_MIN_COLS
            )));
        }
        let param_rows = result_sets[1].collect_all()?;

        let mut parameters = HashMap::new();
        for param_row in &param_rows {
            let name = describe_nvarchar(param_row, 1, "parameter_name")?;
            let type_byte = describe_tinyint(param_row, 3, "column_encryption_type")?;
            // 0 = plaintext: the parameter is not encrypted and needs no crypto metadata.
            if type_byte == 0 {
                continue;
            }
            let encryption_type = match EncryptionTypeWire::from_u8(type_byte) {
                Some(kind) => kind,
                None => {
                    return Err(Error::Protocol(format!(
                        "sp_describe_parameter_encryption: invalid column_encryption_type {type_byte} for {name}"
                    )));
                }
            };
            let algorithm_id = describe_tinyint(param_row, 2, "column_encryption_algorithm")?;
            let server_ordinal = describe_int(param_row, 4, "column_encryption_key_ordinal")?;
            let normalization_rule_version = describe_tinyint(
                param_row,
                5,
                "column_encryption_normalization_rule_version",
            )?;

            let cek_ordinal = match position_by_ordinal.get(&server_ordinal) {
                Some(&position) => position,
                None => {
                    return Err(Error::Protocol(format!(
                        "sp_describe_parameter_encryption: parameter {name} references CEK ordinal {server_ordinal} absent from the CEK table"
                    )));
                }
            };

            parameters.insert(
                name,
                ParameterCryptoInfo {
                    cek_ordinal,
                    encryption_type,
                    algorithm_id,
                    normalization_rule_version,
                },
            );
        }

        Ok(Self {
            cek_table,
            parameters,
        })
    }
}
514
/// Read column `idx` of a `sp_describe_parameter_encryption` row as an `int`.
///
/// `col` is the column name used in the error message.
#[cfg(feature = "always-encrypted")]
fn describe_int(row: &Row, idx: usize, col: &str) -> Result<i32, Error> {
    let raw = row.get_raw(idx);
    if let Some(SqlValue::Int(value)) = raw {
        return Ok(value);
    }
    Err(describe_type_error(col, idx, "int", raw.as_ref()))
}
523
/// Read column `idx` of a `sp_describe_parameter_encryption` row as a `tinyint`.
///
/// `col` is the column name used in the error message.
#[cfg(feature = "always-encrypted")]
fn describe_tinyint(row: &Row, idx: usize, col: &str) -> Result<u8, Error> {
    let raw = row.get_raw(idx);
    if let Some(SqlValue::TinyInt(value)) = raw {
        return Ok(value);
    }
    Err(describe_type_error(col, idx, "tinyint", raw.as_ref()))
}
532
/// Read column `idx` of a `sp_describe_parameter_encryption` row as an `nvarchar` string.
///
/// `col` is the column name used in the error message.
#[cfg(feature = "always-encrypted")]
fn describe_nvarchar(row: &Row, idx: usize, col: &str) -> Result<String, Error> {
    let raw = row.get_raw(idx);
    if let Some(SqlValue::String(text)) = raw {
        return Ok(text);
    }
    Err(describe_type_error(col, idx, "nvarchar", raw.as_ref()))
}
541
/// Read column `idx` of a `sp_describe_parameter_encryption` row as `varbinary` bytes.
///
/// `col` is the column name used in the error message.
#[cfg(feature = "always-encrypted")]
fn describe_varbinary(row: &Row, idx: usize, col: &str) -> Result<bytes::Bytes, Error> {
    let raw = row.get_raw(idx);
    if let Some(SqlValue::Binary(blob)) = raw {
        return Ok(blob);
    }
    Err(describe_type_error(col, idx, "varbinary", raw.as_ref()))
}
550
/// Read the `binary(8)` CEK metadata version at column `idx`.
///
/// The 8 bytes are interpreted as a little-endian `u64`. Any other type or length is a
/// protocol error.
#[cfg(feature = "always-encrypted")]
fn describe_md_version(row: &Row, idx: usize) -> Result<u64, Error> {
    let raw = row.get_raw(idx);
    if let Some(SqlValue::Binary(payload)) = &raw {
        if let Ok(le_bytes) = <[u8; 8]>::try_from(&payload[..]) {
            return Ok(u64::from_le_bytes(le_bytes));
        }
    }
    Err(describe_type_error(
        "column_encryption_key_metadata_version",
        idx,
        "binary(8)",
        raw.as_ref(),
    ))
}
568
/// Build the protocol error reported when a describe-result column has an unexpected type.
///
/// `got` is `None` when the column was absent from the row.
#[cfg(feature = "always-encrypted")]
fn describe_type_error(col: &str, idx: usize, expected: &str, got: Option<&SqlValue>) -> Error {
    let actual = match got {
        Some(value) => value.type_name(),
        None => "missing",
    };
    Error::Protocol(format!(
        "sp_describe_parameter_encryption column {col} (#{idx}): expected {expected}, got {actual}"
    ))
}
577
/// Serialize `value` into the plaintext byte layout that Always Encrypted encrypts.
///
/// The layout is meant to match Microsoft.Data.SqlClient's normalization (the tests pin
/// ciphertexts produced by it).
///
/// `param_type` is the declared type of the encrypted parameter, when known. It selects
/// encodings that cannot be inferred from `value` alone: the `char` code page, temporal
/// scale, and legacy `datetime`. Without it, integers and `bit` are widened to 8-byte
/// little-endian `i64`, strings are UTF-16LE, and binary is passed through unchanged.
///
/// # Errors
///
/// - `EncryptionError::EncryptionFailed` when a `char` value contains characters not
///   representable in Windows-1252.
/// - `EncryptionError::UnsupportedOperation` in these cases:
///   - the temporal scale is outside 0 through 7;
///   - a `SMALLDATETIME` or `MONEY` value cannot be encoded;
///   - no normalization exists for the value's kind (including `NULL`).
#[cfg(feature = "always-encrypted")]
pub fn normalize_for_encryption(
    value: &SqlValue,
    param_type: Option<mssql_types::EncryptedParamType>,
) -> Result<Vec<u8>, EncryptionError> {
    // `char` columns hold single-byte text, so encode in the code page rather than UTF-16.
    if let (Some(mssql_types::EncryptedParamType::Char { .. }), SqlValue::String(s)) =
        (param_type, value)
    {
        let (encoded, _, had_errors) = encoding_rs::WINDOWS_1252.encode(s);
        if had_errors {
            return Err(EncryptionError::EncryptionFailed(
                "char value contains characters not representable in Windows-1252 \
                (the char column code page); use nchar for non-Latin text"
                    .to_string(),
            ));
        }
        return Ok(encoded.into_owned());
    }
    // Temporal encodings whose layout depends on the declared column type.
    #[cfg(feature = "chrono")]
    {
        use mssql_types::EncryptedParamType as E;
        match (param_type, value) {
            (Some(E::Time { scale }), SqlValue::Time(t)) => return normalize_ae_time(*t, scale),
            (Some(E::DateTime2 { scale }), SqlValue::DateTime(dt)) => {
                return normalize_ae_datetime2(*dt, scale);
            }
            (Some(E::DateTimeOffset { scale }), SqlValue::DateTimeOffset(dto)) => {
                return normalize_ae_datetimeoffset(*dto, scale);
            }
            (Some(E::DateTime), SqlValue::DateTime(dt)) => {
                let mut buf = bytes::BytesMut::with_capacity(8);
                mssql_types::__private::encode_datetime_legacy(*dt, &mut buf);
                return Ok(buf.to_vec());
            }
            _ => {}
        }
    }
    match value {
        // All integer widths (and bit) are normalized to an 8-byte little-endian i64.
        SqlValue::Bool(v) => Ok(i64::from(*v).to_le_bytes().to_vec()),
        SqlValue::TinyInt(v) => Ok(i64::from(*v).to_le_bytes().to_vec()),
        SqlValue::SmallInt(v) => Ok(i64::from(*v).to_le_bytes().to_vec()),
        SqlValue::Int(v) => Ok(i64::from(*v).to_le_bytes().to_vec()),
        SqlValue::BigInt(v) => Ok(v.to_le_bytes().to_vec()),
        SqlValue::Float(v) => Ok(v.to_le_bytes().to_vec()),
        SqlValue::Double(v) => Ok(v.to_le_bytes().to_vec()),
        SqlValue::String(s) => Ok(s.encode_utf16().flat_map(u16::to_le_bytes).collect()),
        SqlValue::Binary(b) => Ok(b.to_vec()),
        #[cfg(feature = "uuid")]
        SqlValue::Uuid(u) => {
            // SQL Server GUID byte order: first three groups little-endian, rest as-is.
            let b = u.as_bytes();
            Ok(vec![
                b[3], b[2], b[1], b[0], b[5], b[4], b[7], b[6], b[8], b[9], b[10], b[11], b[12],
                b[13], b[14], b[15],
            ])
        }
        // Same 3-byte day count used inside datetime2/datetimeoffset.
        #[cfg(feature = "chrono")]
        SqlValue::Date(d) => Ok(ae_date_bytes(*d).to_vec()),
        #[cfg(feature = "decimal")]
        SqlValue::Decimal(d) => {
            // Sign byte (1 = non-negative) followed by the 16-byte little-endian magnitude.
            // NOTE(review): the mantissa is encoded at the value's own scale; confirm callers
            // rescale to the column's declared scale first.
            let mut out = Vec::with_capacity(17);
            out.push(u8::from(!d.is_sign_negative()));
            out.extend_from_slice(&d.mantissa().unsigned_abs().to_le_bytes());
            Ok(out)
        }
        #[cfg(feature = "decimal")]
        SqlValue::Money(d) | SqlValue::SmallMoney(d) => {
            // 8-byte money: high 32 bits first, then low 32 bits, each little-endian.
            let cents = money_cents(d)?;
            let mut out = ((cents >> 32) as i32).to_le_bytes().to_vec();
            out.extend_from_slice(&(cents as u32).to_le_bytes());
            Ok(out)
        }
        #[cfg(feature = "chrono")]
        SqlValue::SmallDateTime(dt) => {
            let mut buf = bytes::BytesMut::with_capacity(4);
            mssql_types::__private::encode_smalldatetime(*dt, &mut buf).map_err(|e| {
                EncryptionError::UnsupportedOperation(format!("SMALLDATETIME: {e}"))
            })?;
            Ok(buf.to_vec())
        }
        other => Err(EncryptionError::UnsupportedOperation(format!(
            "Always Encrypted parameter encryption is not yet implemented for {}",
            other.type_name()
        ))),
    }
}
706
/// Encode a date as the 3-byte little-endian count of days since 0001-01-01, the date
/// layout used by `date`, `datetime2` and `datetimeoffset`.
#[cfg(all(feature = "always-encrypted", feature = "chrono"))]
fn ae_date_bytes(d: chrono::NaiveDate) -> [u8; 3] {
    use chrono::Datelike;
    // chrono numbers 0001-01-01 as day 1; the wire format numbers it as day 0.
    let day_number = (d.num_days_from_ce() - 1) as u32;
    let [lo, mid, hi, _] = day_number.to_le_bytes();
    [lo, mid, hi]
}
716
/// Encode a time of day as 5 little-endian bytes of 100 ns ticks since midnight.
///
/// The value is truncated to the precision of `scale` but always kept in 100 ns units.
///
/// # Errors
///
/// `EncryptionError::UnsupportedOperation` if `scale` is greater than 7.
#[cfg(all(feature = "always-encrypted", feature = "chrono"))]
fn normalize_ae_time(t: chrono::NaiveTime, scale: u8) -> Result<Vec<u8>, EncryptionError> {
    use chrono::Timelike;
    if scale > 7 {
        return Err(EncryptionError::UnsupportedOperation(format!(
            "time scale {scale} out of range (0–7)"
        )));
    }
    let whole_seconds = u64::from(t.num_seconds_from_midnight());
    let ticks = whole_seconds * 10_000_000 + u64::from(t.nanosecond()) / 100;
    // Size of one unit at `scale`, expressed in 100 ns ticks.
    let step = 10u64.pow(7 - u32::from(scale));
    let truncated = ticks - ticks % step;
    Ok(truncated.to_le_bytes()[..5].to_vec())
}
744
/// Encode a `datetime2` value as its 5-byte time part followed by its 3-byte date part.
///
/// # Errors
///
/// Propagates the scale-range error from `normalize_ae_time`.
#[cfg(all(feature = "always-encrypted", feature = "chrono"))]
fn normalize_ae_datetime2(
    dt: chrono::NaiveDateTime,
    scale: u8,
) -> Result<Vec<u8>, EncryptionError> {
    let date_part = ae_date_bytes(dt.date());
    normalize_ae_time(dt.time(), scale).map(|mut encoded| {
        encoded.extend_from_slice(&date_part);
        encoded
    })
}
756
/// Encode a `datetimeoffset` value.
///
/// The layout is the UTC time (5 bytes), the UTC date (3 bytes), then the offset in minutes
/// as a little-endian `i16`.
///
/// # Errors
///
/// Propagates the scale-range error from `normalize_ae_time`.
#[cfg(all(feature = "always-encrypted", feature = "chrono"))]
fn normalize_ae_datetimeoffset(
    dto: chrono::DateTime<chrono::FixedOffset>,
    scale: u8,
) -> Result<Vec<u8>, EncryptionError> {
    use chrono::Offset;
    let utc = dto.naive_utc();
    let offset_minutes = (dto.offset().fix().local_minus_utc() / 60) as i16;
    let mut encoded = normalize_ae_time(utc.time(), scale)?;
    encoded.extend_from_slice(&ae_date_bytes(utc.date()));
    encoded.extend_from_slice(&offset_minutes.to_le_bytes());
    Ok(encoded)
}
772
/// Convert a decimal amount into the 64-bit integer stored by `MONEY`/`SMALLMONEY`, in units
/// of 1/10,000.
///
/// When `value` has more than four fractional digits, the result is rounded half away from
/// zero. SQL Server (and SqlClient's `SqlMoney`) use the same rounding when converting to
/// `money`; truncating instead would store a value one unit off.
///
/// # Errors
///
/// Returns `EncryptionError::UnsupportedOperation("MONEY value out of range")` when the scaled
/// value does not fit in an `i64`.
#[cfg(all(feature = "always-encrypted", feature = "decimal"))]
fn money_cents(value: &rust_decimal::Decimal) -> Result<i64, EncryptionError> {
    let mantissa = value.mantissa();
    let scale = value.scale();
    let out_of_range =
        || EncryptionError::UnsupportedOperation("MONEY value out of range".into());
    let cents: i128 = if scale <= 4 {
        mantissa
            .checked_mul(10_i128.pow(4 - scale))
            .ok_or_else(out_of_range)?
    } else {
        // Decimal scale is at most 28, so the divisor (<= 10^24) fits comfortably in i128.
        let divisor = 10_i128.pow(scale - 4);
        let quotient = mantissa / divisor;
        let remainder = mantissa % divisor;
        if remainder.unsigned_abs() * 2 >= divisor.unsigned_abs() {
            // A non-zero remainder shares the mantissa's sign, so this moves away from zero.
            quotient + mantissa.signum()
        } else {
            quotient
        }
    };
    i64::try_from(cents).map_err(|_| out_of_range())
}
791
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Decode an even-length hex string into bytes. Fixtures are trusted, so malformed input
    /// panics.
    #[cfg(feature = "always-encrypted")]
    fn unhex(s: &str) -> Vec<u8> {
        (0..s.len())
            .step_by(2)
            .map(|i| u8::from_str_radix(&s[i..i + 2], 16).unwrap())
            .collect()
    }

    /// Check each `(value, reference)` case against Microsoft.Data.SqlClient output.
    ///
    /// `value` is normalized with no declared parameter type and encrypted deterministically
    /// under the hex-encoded CEK `cek_hex`. The result must equal the hex ciphertext
    /// `reference`.
    #[cfg(feature = "always-encrypted")]
    fn assert_matches_dotnet(cek_hex: &str, cases: &[(SqlValue, &str)]) {
        let enc = AeadEncryptor::new(&unhex(cek_hex)).unwrap();
        for (value, reference) in cases {
            let norm = normalize_for_encryption(value, None).unwrap();
            let cipher = enc
                .encrypt(&norm, mssql_auth::EncryptionType::Deterministic)
                .unwrap();
            assert_eq!(
                cipher,
                unhex(reference),
                "ciphertext for {} must match Microsoft.Data.SqlClient",
                value.type_name()
            );
        }
    }

    /// Int, string and binary values encrypt exactly as SqlClient encrypts them.
    #[cfg(feature = "always-encrypted")]
    #[test]
    fn ae_normalization_matches_dotnet() {
        use bytes::Bytes;

        assert_matches_dotnet(
            "B59D9F2C96784C232D53AB273D257DC79B7D2355BB82B1EC7054CE25E25F7B44",
            &[
                (
                    SqlValue::Int(42),
                    "01102FC5DEC5D3E463A8F4BDF512AA74E6AB953BA9A2F3F9A98CD18446B007DE5A6E2A1D1EB775035EA189CA5160A935CE093CAA9BB7E9233BB333AADEE86FDE1D",
                ),
                (
                    SqlValue::String("Ada".to_string()),
                    "01BFAC40E6DA541ACEFAD8ECF5598DB77B0C5349CFACBC3C9221C01B6037E593B78E8F398F620F837BD6A4A2B644125C4188DF278B94479B2218466D91107FE417",
                ),
                (
                    SqlValue::Binary(Bytes::from_static(&[0x01, 0x02, 0x03])),
                    "01ADE71457495F00FC9A16456F1B1EECB901D88DE97887025C189B1C4432E02071AB7594C48518CA5621E90165FAE337475B4CF3A3D00EF2D862FB0473713DF1E1",
                ),
            ],
        );
    }

    /// Values without a normalization (here `NULL`) are rejected rather than encrypted.
    #[cfg(feature = "always-encrypted")]
    #[test]
    fn ae_normalization_rejects_unnormalizable_value() {
        assert!(normalize_for_encryption(&SqlValue::Null, None).is_err());
    }

    /// Integer widths, bit and floating-point values match SqlClient (integers and bit are
    /// all widened to 8 bytes).
    #[cfg(feature = "always-encrypted")]
    #[test]
    fn ae_normalization_matches_dotnet_numeric() {
        assert_matches_dotnet(
            "9590E42A8A6C8F13B5D09B8D5A128EF8B3A4A10301C7AF24AFC62ED0E02342F7",
            &[
                (
                    SqlValue::BigInt(0x0102030405060708),
                    "01E765FC4696660028BFD48FCAEAED81E0EB423CFF433CA97F1B2FF02F70744E7265C2AE73CAA562FFA98AF98CB1D3EF6A4649B3640359E1DB7D170C80E639DA68",
                ),
                (
                    SqlValue::SmallInt(258),
                    "012545AB817E1AEBDCEE1C00AEBFF3A013CAD20E0377BEFDD9186C263F8D1A909C313A753996F1B5E4A4AE17E901F6F781DCA707544766995D339601CA414063A0",
                ),
                (
                    SqlValue::TinyInt(200),
                    "01A97C33480277D16FFAEDA9068173D4173378542F2887EBCD31CDEEEB116BD59D48F9D459BDDCABAE469E891B4F82AA3D283440CA1B5E9FFC150F9D0AE54EC21E",
                ),
                (
                    SqlValue::Bool(true),
                    "01DDE18564051D630EE026331BCCAFC8F4122CC3919F81459F37D9C0E0C64A5317FCA08660FE5FC855917B97B72013F25B85ADD14ADDD7D5ED022EB1297FF29A7E",
                ),
                (
                    SqlValue::Float(3.5),
                    "017A452760E7BA7AA6A716F6707F55D9C3A81683C04A6B561B13AC1D8A848E93E239BB922EE3EE628B6D0081A590BB11747CC25D216240FB10171A0FA3B99A2DB3",
                ),
                (
                    SqlValue::Double(3.5),
                    "0171611557351FBC4561EBF0B9C98E0DC38AD2BD3E2C1D1E82F185D7E67D0425E506D11DD67BA3EB38F34FB01A8FCEF7E4B9A7256944334A521526613CFF6C8C5F",
                ),
            ],
        );
    }

    /// GUIDs (mixed-endian byte order) and dates (3-byte day count) match SqlClient.
    #[cfg(all(feature = "always-encrypted", feature = "uuid", feature = "chrono"))]
    #[test]
    fn ae_normalization_matches_dotnet_uuid_date() {
        assert_matches_dotnet(
            "9590E42A8A6C8F13B5D09B8D5A128EF8B3A4A10301C7AF24AFC62ED0E02342F7",
            &[
                (
                    SqlValue::Uuid(
                        uuid::Uuid::parse_str("01020304-0506-0708-090a-0b0c0d0e0f10").unwrap(),
                    ),
                    "01F58635AA18692D68BDF551ECDD7AC3A56682D3F91F111F8D8F36D5425C405A8F6AB3ED3C3666444478476BD65FF40DC83F6831F502826AFEEC3116F71A7A2020CCD254F4BA28FCDC0F96BA2E5264AE9E",
                ),
                (
                    SqlValue::Date(chrono::NaiveDate::from_ymd_opt(2024, 3, 15).unwrap()),
                    "0188B4F75A1F4BDA53C9CDDC1918C09CB57F68E13F5560F1F1D7168FE70707337B1156A97915B244F3C03D3E7352882A599511BD243471FD03683F371CF44E4B76",
                ),
            ],
        );
    }

    /// Decimal (sign + 16-byte magnitude) and money/smallmoney (8-byte scaled integer)
    /// match SqlClient; smallmoney normalizes identically to money.
    #[cfg(all(feature = "always-encrypted", feature = "decimal"))]
    #[test]
    fn ae_normalization_matches_dotnet_decimal_money() {
        let dec = rust_decimal::Decimal::new(123_456_789, 4); // 12345.6789
        let money = rust_decimal::Decimal::new(123_400, 4); // 12.3400

        assert_matches_dotnet(
            "CBFB5AE21FB517C65DA0C6E8E11969C630798E473EF5827A70398012DF1D4B9E",
            &[
                (
                    SqlValue::Decimal(dec),
                    "018FAE46024B9B406C23600E6A9C694F9A9B39B785A995689EBE19437BA7E75768011A035A5B54B5E495512EBB46AE1146130940A0D0D834D61AA89B5AD9F71FFAF6EEEAE77E4856BA2AA5E016E2950A8D",
                ),
                (
                    SqlValue::Money(money),
                    "01B4CE4CAD8D6B241A1555C377A0ADD4C79424DD5162F710D116594F725C1BAB015169A0C7716076EEC90E013519B961DEF427BFC32462D9E45D166C791B73F793",
                ),
                (
                    SqlValue::SmallMoney(money),
                    "01B4CE4CAD8D6B241A1555C377A0ADD4C79424DD5162F710D116594F725C1BAB015169A0C7716076EEC90E013519B961DEF427BFC32462D9E45D166C791B73F793",
                ),
            ],
        );
    }

    /// Temporal values normalize to the expected plaintext bytes for each declared type
    /// and scale.
    #[cfg(all(feature = "always-encrypted", feature = "chrono"))]
    #[test]
    fn ae_normalization_matches_dotnet_temporal() {
        use mssql_types::EncryptedParamType as E;

        let day = chrono::NaiveDate::from_ymd_opt(2024, 3, 15).unwrap();
        let dt = day.and_hms_nano_opt(13, 14, 15, 123_456_700).unwrap();

        // Scale 7: every 100 ns tick is kept. Time is 5 bytes; datetime2 appends a 3-byte date.
        assert_eq!(
            normalize_for_encryption(&SqlValue::Time(dt.time()), Some(E::Time { scale: 7 }))
                .unwrap(),
            unhex("07c4aaf46e"),
        );
        assert_eq!(
            normalize_for_encryption(&SqlValue::DateTime(dt), Some(E::DateTime2 { scale: 7 }))
                .unwrap(),
            unhex("07c4aaf46e8f460b"),
        );
        // +05:30: time and date are stored in UTC, followed by the offset in minutes (330).
        let dto = {
            use chrono::TimeZone;
            chrono::FixedOffset::east_opt(5 * 3600 + 30 * 60)
                .unwrap()
                .from_local_datetime(&dt)
                .single()
                .unwrap()
        };
        assert_eq!(
            normalize_for_encryption(
                &SqlValue::DateTimeOffset(dto),
                Some(E::DateTimeOffset { scale: 7 })
            )
            .unwrap(),
            unhex("0788f2da408f460b4a01"),
        );

        // Scale 3: truncated to whole milliseconds but still expressed in 100 ns ticks.
        assert_eq!(
            normalize_for_encryption(&SqlValue::Time(dt.time()), Some(E::Time { scale: 3 }))
                .unwrap(),
            unhex("30b2aaf46e"),
        );
        assert_eq!(
            normalize_for_encryption(&SqlValue::DateTime(dt), Some(E::DateTime2 { scale: 3 }))
                .unwrap(),
            unhex("30b2aaf46e8f460b"),
        );
        assert_eq!(
            normalize_for_encryption(
                &SqlValue::DateTimeOffset(dto),
                Some(E::DateTimeOffset { scale: 3 })
            )
            .unwrap(),
            unhex("3076f2da408f460b4a01"),
        );

        // Legacy datetime: days since 1900-01-01, then 1/300-second ticks.
        let dt_legacy = day.and_hms_milli_opt(13, 14, 15, 123).unwrap();
        assert_eq!(
            normalize_for_encryption(&SqlValue::DateTime(dt_legacy), Some(E::DateTime)).unwrap(),
            unhex("34b10000d925da00"),
        );
        // smalldatetime: days since 1900-01-01, then minutes since midnight.
        let sdt = day.and_hms_opt(13, 14, 0).unwrap();
        assert_eq!(
            normalize_for_encryption(&SqlValue::SmallDateTime(sdt), None).unwrap(),
            unhex("34b11a03"),
        );
    }

    /// Fixed-width char/nchar/binary values are normalized without padding to the declared
    /// length.
    #[cfg(feature = "always-encrypted")]
    #[test]
    fn ae_normalization_matches_dotnet_fixed_width() {
        use mssql_types::EncryptedParamType as E;

        assert_eq!(
            normalize_for_encryption(
                &SqlValue::String("Hello".to_string()),
                Some(E::Char { length: 10 })
            )
            .unwrap(),
            unhex("48656c6c6f"),
        );
        assert_eq!(
            normalize_for_encryption(
                &SqlValue::String("Hello".to_string()),
                Some(E::NChar { length: 10 })
            )
            .unwrap(),
            unhex("480065006c006c006f00"),
        );
        assert_eq!(
            normalize_for_encryption(
                &SqlValue::Binary(bytes::Bytes::from_static(&[1, 2, 3, 4, 5])),
                Some(E::Binary { length: 10 })
            )
            .unwrap(),
            unhex("0102030405"),
        );
    }

    /// `char` parameters reject text outside Windows-1252 instead of silently substituting,
    /// while Latin-1 characters such as `é` encode to a single byte.
    #[cfg(feature = "always-encrypted")]
    #[test]
    fn ae_char_rejects_non_windows_1252() {
        use mssql_types::EncryptedParamType as E;

        let r = normalize_for_encryption(
            &SqlValue::String("中".to_string()),
            Some(E::Char { length: 10 }),
        );
        assert!(
            r.is_err(),
            "non-Windows-1252 char must error, got {:?}",
            r.map(|b| b.iter().map(|x| format!("{x:02x}")).collect::<String>())
        );
        assert_eq!(
            normalize_for_encryption(
                &SqlValue::String("é".to_string()),
                Some(E::Char { length: 10 })
            )
            .unwrap(),
            vec![0xE9],
        );
    }

    /// `new()` enables encryption and caching but is not ready until a provider is added.
    #[test]
    fn test_encryption_config_defaults() {
        let config = EncryptionConfig::new();
        assert!(config.enabled);
        assert!(config.cache_ceks);
        assert!(!config.is_ready());
    }

    /// CMK wrappings sharing a key ordinal are grouped into one CEK entry, server ordinals
    /// map to positional indices, and plaintext parameters are skipped.
    #[cfg(feature = "always-encrypted")]
    #[test]
    fn parse_describe_result_sets_groups_ceks_and_skips_plaintext() {
        use crate::row::{Column, Row};
        use crate::stream::ResultSet;
        use bytes::Bytes;

        fn rs(n_cols: usize, rows: Vec<Vec<SqlValue>>) -> ResultSet {
            let cols: Vec<Column> = (0..n_cols)
                .map(|i| Column::new(format!("c{i}"), i, "x"))
                .collect();
            let rows = rows
                .into_iter()
                .map(|vals| Row::from_values(cols.clone(), vals))
                .collect();
            ResultSet::new(cols, rows)
        }

        // Metadata versions are binary(8), little-endian.
        let mdv1 = Bytes::from_static(&[1, 0, 0, 0, 0, 0, 0, 0]);
        let mdv2 = Bytes::from_static(&[255, 0, 0, 0, 0, 0, 0, 0]);

        // Result set 1: ordinal, database_id, cek_id, cek_version, md_version,
        // encrypted_value, provider_name, cmk_path, algorithm.
        let rs1 = rs(
            9,
            vec![
                vec![
                    SqlValue::Int(1),
                    SqlValue::Int(7),
                    SqlValue::Int(56),
                    SqlValue::Int(1),
                    SqlValue::Binary(mdv1.clone()),
                    SqlValue::Binary(Bytes::from_static(b"env-a")),
                    SqlValue::String("IN_MEMORY_KEY_STORE".into()),
                    SqlValue::String("path-a".into()),
                    SqlValue::String("RSA_OAEP".into()),
                ],
                vec![
                    SqlValue::Int(1),
                    SqlValue::Int(7),
                    SqlValue::Int(56),
                    SqlValue::Int(1),
                    SqlValue::Binary(mdv1),
                    SqlValue::Binary(Bytes::from_static(b"env-a2")),
                    SqlValue::String("PROV_2".into()),
                    SqlValue::String("path-a2".into()),
                    SqlValue::String("RSA_OAEP".into()),
                ],
                vec![
                    SqlValue::Int(2),
                    SqlValue::Int(7),
                    SqlValue::Int(57),
                    SqlValue::Int(1),
                    SqlValue::Binary(mdv2),
                    SqlValue::Binary(Bytes::from_static(b"env-b")),
                    SqlValue::String("IN_MEMORY_KEY_STORE".into()),
                    SqlValue::String("path-b".into()),
                    SqlValue::String("RSA_OAEP".into()),
                ],
            ],
        );

        // Result set 2: parameter_ordinal, name, algorithm, encryption_type, key_ordinal,
        // normalization_rule_version.
        let rs2 = rs(
            6,
            vec![
                vec![
                    SqlValue::Int(1),
                    SqlValue::String("@det".into()),
                    SqlValue::TinyInt(2),
                    SqlValue::TinyInt(1),
                    SqlValue::Int(1),
                    SqlValue::TinyInt(1),
                ],
                vec![
                    SqlValue::Int(2),
                    SqlValue::String("@rand".into()),
                    SqlValue::TinyInt(2),
                    SqlValue::TinyInt(2),
                    SqlValue::Int(2),
                    SqlValue::TinyInt(1),
                ],
                vec![
                    SqlValue::Int(3),
                    SqlValue::String("@plain".into()),
                    SqlValue::TinyInt(0),
                    SqlValue::TinyInt(0),
                    SqlValue::Int(0),
                    SqlValue::TinyInt(0),
                ],
            ],
        );

        let mut sets = vec![rs1, rs2];
        let info = ParameterEncryptionInfo::from_describe_result_sets(&mut sets).unwrap();

        assert_eq!(info.cek_table.len(), 2);
        let e0 = info.cek_table.get(0).unwrap();
        assert_eq!(e0.cek_id, 56);
        assert_eq!(e0.cek_md_version, 1);
        assert_eq!(e0.values.len(), 2, "two CMK-wrappings group under one CEK");
        assert_eq!(e0.values[0].key_store_provider_name, "IN_MEMORY_KEY_STORE");
        assert_eq!(e0.values[1].key_store_provider_name, "PROV_2");
        let e1 = info.cek_table.get(1).unwrap();
        assert_eq!(e1.cek_id, 57);
        assert_eq!(e1.cek_md_version, 255);

        let det = info.get_parameter("@det").unwrap();
        assert_eq!(det.encryption_type, EncryptionTypeWire::Deterministic);
        assert_eq!(det.algorithm_id, 2);
        assert_eq!(det.normalization_rule_version, 1);
        assert_eq!(det.cek_ordinal, 0, "server ordinal 1 -> positional index 0");

        let rand = info.get_parameter("@rand").unwrap();
        assert_eq!(rand.encryption_type, EncryptionTypeWire::Randomized);
        assert_eq!(
            rand.cek_ordinal, 1,
            "server ordinal 2 -> positional index 1"
        );

        assert!(info.get_parameter("@plain").is_none());
        assert_eq!(info.parameters.len(), 2);
    }

    /// A response with only one result set is a protocol error.
    #[cfg(feature = "always-encrypted")]
    #[test]
    fn parse_describe_result_sets_rejects_missing_result_set() {
        use crate::row::{Column, Row};
        use crate::stream::ResultSet;

        let cols: Vec<Column> = (0..9)
            .map(|i| Column::new(format!("c{i}"), i, "x"))
            .collect();
        let mut sets = vec![ResultSet::new(cols, Vec::<Row>::new())];
        assert!(ParameterEncryptionInfo::from_describe_result_sets(&mut sets).is_err());
    }
}