1use std::collections::HashMap;
96
97use mssql_auth::KeyStoreProvider;
98use tds_protocol::crypto::{CekTable, CekTableEntry, CryptoMetadata, EncryptionTypeWire};
99
100#[cfg(feature = "always-encrypted")]
101use mssql_auth::{AeadEncryptor, CekCache, CekCacheKey, EncryptionError};
102#[cfg(feature = "always-encrypted")]
103use mssql_types::SqlValue;
104#[cfg(feature = "always-encrypted")]
105use std::sync::Arc;
106
107#[cfg(feature = "always-encrypted")]
108use crate::{Error, row::Row, stream::ResultSet};
109#[cfg(feature = "always-encrypted")]
110use tds_protocol::crypto::CekValue;
111
112#[derive(Default)]
114pub struct EncryptionConfig {
115 pub enabled: bool,
117 providers: Vec<Box<dyn KeyStoreProvider>>,
119 pub cache_ceks: bool,
121}
122
123impl EncryptionConfig {
124 #[must_use]
126 pub fn new() -> Self {
127 Self {
128 enabled: true,
129 providers: Vec::new(),
130 cache_ceks: true,
131 }
132 }
133
134 pub fn register_provider(&mut self, provider: impl KeyStoreProvider + 'static) {
136 self.providers.push(Box::new(provider));
137 }
138
139 #[must_use]
141 pub fn with_provider(mut self, provider: impl KeyStoreProvider + 'static) -> Self {
142 self.register_provider(provider);
143 self
144 }
145
146 #[must_use]
148 pub fn with_cek_caching(mut self, enabled: bool) -> Self {
149 self.cache_ceks = enabled;
150 self
151 }
152
153 pub fn get_provider(&self, name: &str) -> Option<&dyn KeyStoreProvider> {
155 self.providers
156 .iter()
157 .find(|p| p.provider_name() == name)
158 .map(|p| p.as_ref())
159 }
160
161 #[must_use]
163 pub fn is_ready(&self) -> bool {
164 self.enabled && !self.providers.is_empty()
165 }
166}
167
168impl std::fmt::Debug for EncryptionConfig {
169 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
170 f.debug_struct("EncryptionConfig")
171 .field("enabled", &self.enabled)
172 .field("provider_count", &self.providers.len())
173 .field("cache_ceks", &self.cache_ceks)
174 .finish()
175 }
176}
177
178#[cfg(feature = "always-encrypted")]
187pub struct EncryptionContext {
188 config: std::sync::Arc<EncryptionConfig>,
192 cek_cache: CekCache,
194 cache_enabled: bool,
196}
197
198#[cfg(feature = "always-encrypted")]
199impl EncryptionContext {
200 pub fn from_arc(config: std::sync::Arc<EncryptionConfig>) -> Self {
206 let cache_enabled = config.cache_ceks;
207 Self {
208 config,
209 cek_cache: CekCache::new(),
210 cache_enabled,
211 }
212 }
213
214 pub fn new(config: EncryptionConfig) -> Self {
216 Self::from_arc(std::sync::Arc::new(config))
217 }
218
219 pub async fn get_encryptor(
226 &self,
227 cek_entry: &CekTableEntry,
228 ) -> Result<Arc<AeadEncryptor>, EncryptionError> {
229 let cache_key = CekCacheKey::new(
230 cek_entry.database_id,
231 cek_entry.cek_id,
232 cek_entry.cek_version,
233 );
234
235 if self.cache_enabled {
237 if let Some(encryptor) = self.cek_cache.get(&cache_key) {
238 return Ok(encryptor);
239 }
240 }
241
242 let cek_value = cek_entry
244 .primary_value()
245 .ok_or_else(|| EncryptionError::CekDecryptionFailed("No CEK value available".into()))?;
246
247 let provider = self
249 .config
250 .get_provider(&cek_value.key_store_provider_name)
251 .ok_or_else(|| {
252 EncryptionError::KeyStoreNotFound(cek_value.key_store_provider_name.clone())
253 })?;
254
255 let decrypted_cek = provider
257 .decrypt_cek(
258 &cek_value.cmk_path,
259 &cek_value.encryption_algorithm,
260 &cek_value.encrypted_value,
261 )
262 .await?;
263
264 if self.cache_enabled {
266 self.cek_cache.insert(cache_key, decrypted_cek)
267 } else {
268 Ok(Arc::new(AeadEncryptor::new(&decrypted_cek)?))
270 }
271 }
272
273 pub async fn encrypt_value(
281 &self,
282 plaintext: &[u8],
283 cek_entry: &CekTableEntry,
284 encryption_type: EncryptionTypeWire,
285 ) -> Result<Vec<u8>, EncryptionError> {
286 let encryptor = self.get_encryptor(cek_entry).await?;
287
288 let enc_type = match encryption_type {
289 EncryptionTypeWire::Deterministic => mssql_auth::EncryptionType::Deterministic,
290 EncryptionTypeWire::Randomized => mssql_auth::EncryptionType::Randomized,
291 _ => {
292 return Err(EncryptionError::UnsupportedOperation(format!(
293 "unsupported encryption type: {encryption_type:?}"
294 )));
295 }
296 };
297
298 encryptor.encrypt(plaintext, enc_type)
299 }
300
301 pub async fn decrypt_value(
308 &self,
309 ciphertext: &[u8],
310 cek_entry: &CekTableEntry,
311 ) -> Result<Vec<u8>, EncryptionError> {
312 let encryptor = self.get_encryptor(cek_entry).await?;
313 encryptor.decrypt(ciphertext)
314 }
315
316 pub fn clear_cache(&self) {
320 self.cek_cache.clear();
321 }
322
323 pub fn has_provider(&self, name: &str) -> bool {
325 self.config.get_provider(name).is_some()
326 }
327}
328
329#[cfg(feature = "always-encrypted")]
330impl std::fmt::Debug for EncryptionContext {
331 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
332 f.debug_struct("EncryptionContext")
333 .field("provider_count", &self.config.providers.len())
334 .field("cache_entries", &self.cek_cache.len())
335 .field("cache_enabled", &self.cache_enabled)
336 .finish()
337 }
338}
339
340#[derive(Debug, Clone)]
345pub struct ResultSetEncryptionInfo {
346 pub cek_table: CekTable,
348 pub column_crypto: Vec<Option<CryptoMetadata>>,
350}
351
352impl ResultSetEncryptionInfo {
353 pub fn new(cek_table: CekTable, column_count: usize) -> Self {
355 Self {
356 cek_table,
357 column_crypto: vec![None; column_count],
358 }
359 }
360
361 pub fn set_column_crypto(&mut self, ordinal: usize, metadata: CryptoMetadata) {
363 if ordinal < self.column_crypto.len() {
364 self.column_crypto[ordinal] = Some(metadata);
365 }
366 }
367
368 pub fn get_cek_for_column(&self, ordinal: usize) -> Option<&CekTableEntry> {
370 let crypto = self.column_crypto.get(ordinal)?.as_ref()?;
371 self.cek_table.get(crypto.cek_table_ordinal)
372 }
373
374 pub fn is_column_encrypted(&self, ordinal: usize) -> bool {
376 self.column_crypto
377 .get(ordinal)
378 .map(|c| c.is_some())
379 .unwrap_or(false)
380 }
381
382 pub fn get_encryption_type(&self, ordinal: usize) -> Option<EncryptionTypeWire> {
384 self.column_crypto
385 .get(ordinal)?
386 .as_ref()
387 .map(|c| c.encryption_type)
388 }
389}
390
391#[derive(Debug, Clone)]
396pub struct ParameterEncryptionInfo {
397 pub cek_table: CekTable,
399 pub parameters: HashMap<String, ParameterCryptoInfo>,
401}
402
403impl ParameterEncryptionInfo {
404 pub fn new() -> Self {
406 Self {
407 cek_table: CekTable::new(),
408 parameters: HashMap::new(),
409 }
410 }
411
412 pub fn add_parameter(&mut self, name: String, info: ParameterCryptoInfo) {
414 self.parameters.insert(name, info);
415 }
416
417 pub fn get_parameter(&self, name: &str) -> Option<&ParameterCryptoInfo> {
419 self.parameters.get(name)
420 }
421
422 pub fn needs_encryption(&self, name: &str) -> bool {
424 self.parameters.contains_key(name)
425 }
426}
427
428impl Default for ParameterEncryptionInfo {
429 fn default() -> Self {
430 Self::new()
431 }
432}
433
434#[derive(Debug, Clone)]
437pub struct ParameterCryptoInfo {
438 pub cek_ordinal: u16,
444 pub encryption_type: EncryptionTypeWire,
446 pub algorithm_id: u8,
448 pub normalization_rule_version: u8,
450}
451
452impl ParameterCryptoInfo {
453 pub fn new(
455 cek_ordinal: u16,
456 encryption_type: EncryptionTypeWire,
457 algorithm_id: u8,
458 normalization_rule_version: u8,
459 ) -> Self {
460 Self {
461 cek_ordinal,
462 encryption_type,
463 algorithm_id,
464 normalization_rule_version,
465 }
466 }
467}
468
469#[cfg(feature = "always-encrypted")]
479impl ParameterEncryptionInfo {
480 const RS1_MIN_COLS: usize = 9;
483 const RS2_MIN_COLS: usize = 6;
485
486 pub(crate) fn from_describe_result_sets(result_sets: &mut [ResultSet]) -> Result<Self, Error> {
491 if result_sets.len() < 2 {
492 return Err(Error::Protocol(format!(
493 "sp_describe_parameter_encryption returned {} result set(s), expected 2",
494 result_sets.len()
495 )));
496 }
497
498 let rs1_cols = result_sets[0].columns().len();
500 if rs1_cols < Self::RS1_MIN_COLS {
501 return Err(Error::Protocol(format!(
502 "sp_describe_parameter_encryption result set 1 has {rs1_cols} columns, expected >= {}",
503 Self::RS1_MIN_COLS
504 )));
505 }
506 let rs1_rows = result_sets[0].collect_all()?;
507
508 let mut entries: Vec<CekTableEntry> = Vec::new();
509 let mut ordinal_to_index: HashMap<i32, u16> = HashMap::new();
511
512 for row in &rs1_rows {
513 let key_ordinal = describe_int(row, 0, "column_encryption_key_ordinal")?;
514 let value = CekValue {
515 encrypted_value: describe_varbinary(
516 row,
517 5,
518 "column_encryption_key_encrypted_value",
519 )?,
520 key_store_provider_name: describe_nvarchar(
521 row,
522 6,
523 "column_master_key_store_provider_name",
524 )?,
525 cmk_path: describe_nvarchar(row, 7, "column_master_key_path")?,
526 encryption_algorithm: describe_nvarchar(
527 row,
528 8,
529 "column_encryption_key_encryption_algorithm_name",
530 )?,
531 };
532
533 if let Some(&idx) = ordinal_to_index.get(&key_ordinal) {
534 entries[idx as usize].values.push(value);
536 } else {
537 let idx = u16::try_from(entries.len()).map_err(|_| {
538 Error::Protocol(
539 "sp_describe_parameter_encryption returned too many CEKs".into(),
540 )
541 })?;
542 ordinal_to_index.insert(key_ordinal, idx);
543 entries.push(CekTableEntry {
544 database_id: describe_int(row, 1, "database_id")? as u32,
545 cek_id: describe_int(row, 2, "column_encryption_key_id")? as u32,
546 cek_version: describe_int(row, 3, "column_encryption_key_version")? as u32,
547 cek_md_version: describe_md_version(row, 4)?,
548 values: vec![value],
549 });
550 }
551 }
552 let cek_table = CekTable { entries };
553
554 let rs2_cols = result_sets[1].columns().len();
556 if rs2_cols < Self::RS2_MIN_COLS {
557 return Err(Error::Protocol(format!(
558 "sp_describe_parameter_encryption result set 2 has {rs2_cols} columns, expected >= {}",
559 Self::RS2_MIN_COLS
560 )));
561 }
562 let rs2_rows = result_sets[1].collect_all()?;
563
564 let mut parameters = HashMap::new();
565 for row in &rs2_rows {
566 let name = describe_nvarchar(row, 1, "parameter_name")?;
567 let encryption_type_byte = describe_tinyint(row, 3, "column_encryption_type")?;
568 if encryption_type_byte == 0 {
570 continue;
571 }
572 let encryption_type =
573 EncryptionTypeWire::from_u8(encryption_type_byte).ok_or_else(|| {
574 Error::Protocol(format!(
575 "sp_describe_parameter_encryption: invalid column_encryption_type {encryption_type_byte} for {name}"
576 ))
577 })?;
578 let algorithm_id = describe_tinyint(row, 2, "column_encryption_algorithm")?;
579 let server_ordinal = describe_int(row, 4, "column_encryption_key_ordinal")?;
580 let normalization_rule_version =
581 describe_tinyint(row, 5, "column_encryption_normalization_rule_version")?;
582
583 let cek_ordinal = *ordinal_to_index.get(&server_ordinal).ok_or_else(|| {
584 Error::Protocol(format!(
585 "sp_describe_parameter_encryption: parameter {name} references CEK ordinal {server_ordinal} absent from the CEK table"
586 ))
587 })?;
588
589 parameters.insert(
590 name,
591 ParameterCryptoInfo {
592 cek_ordinal,
593 encryption_type,
594 algorithm_id,
595 normalization_rule_version,
596 },
597 );
598 }
599
600 Ok(Self {
601 cek_table,
602 parameters,
603 })
604 }
605}
606
607#[cfg(feature = "always-encrypted")]
609fn describe_int(row: &Row, idx: usize, col: &str) -> Result<i32, Error> {
610 match row.get_raw(idx) {
611 Some(SqlValue::Int(v)) => Ok(v),
612 other => Err(describe_type_error(col, idx, "int", other.as_ref())),
613 }
614}
615
616#[cfg(feature = "always-encrypted")]
618fn describe_tinyint(row: &Row, idx: usize, col: &str) -> Result<u8, Error> {
619 match row.get_raw(idx) {
620 Some(SqlValue::TinyInt(v)) => Ok(v),
621 other => Err(describe_type_error(col, idx, "tinyint", other.as_ref())),
622 }
623}
624
625#[cfg(feature = "always-encrypted")]
627fn describe_nvarchar(row: &Row, idx: usize, col: &str) -> Result<String, Error> {
628 match row.get_raw(idx) {
629 Some(SqlValue::String(v)) => Ok(v),
630 other => Err(describe_type_error(col, idx, "nvarchar", other.as_ref())),
631 }
632}
633
634#[cfg(feature = "always-encrypted")]
636fn describe_varbinary(row: &Row, idx: usize, col: &str) -> Result<bytes::Bytes, Error> {
637 match row.get_raw(idx) {
638 Some(SqlValue::Binary(v)) => Ok(v),
639 other => Err(describe_type_error(col, idx, "varbinary", other.as_ref())),
640 }
641}
642
643#[cfg(feature = "always-encrypted")]
645fn describe_md_version(row: &Row, idx: usize) -> Result<u64, Error> {
646 match row.get_raw(idx) {
647 Some(SqlValue::Binary(b)) if b.len() == 8 => {
648 let mut bytes = [0u8; 8];
649 bytes.copy_from_slice(&b[..8]);
650 Ok(u64::from_le_bytes(bytes))
651 }
652 other => Err(describe_type_error(
653 "column_encryption_key_metadata_version",
654 idx,
655 "binary(8)",
656 other.as_ref(),
657 )),
658 }
659}
660
661#[cfg(feature = "always-encrypted")]
663fn describe_type_error(col: &str, idx: usize, expected: &str, got: Option<&SqlValue>) -> Error {
664 let got = got.map_or("missing", SqlValue::type_name);
665 Error::Protocol(format!(
666 "sp_describe_parameter_encryption column {col} (#{idx}): expected {expected}, got {got}"
667 ))
668}
669
670#[cfg(feature = "always-encrypted")]
684pub fn normalize_for_encryption(
685 value: &SqlValue,
686 param_type: Option<mssql_types::EncryptedParamType>,
687) -> Result<Vec<u8>, EncryptionError> {
688 if let (Some(mssql_types::EncryptedParamType::Char { .. }), SqlValue::String(s)) =
691 (param_type, value)
692 {
693 return Ok(encoding_rs::WINDOWS_1252.encode(s).0.into_owned());
694 }
695 #[cfg(feature = "chrono")]
698 {
699 use mssql_types::EncryptedParamType as E;
700 match (param_type, value) {
701 (Some(E::Time { scale }), SqlValue::Time(t)) => return normalize_ae_time(*t, scale),
702 (Some(E::DateTime2 { scale }), SqlValue::DateTime(dt)) => {
703 return normalize_ae_datetime2(*dt, scale);
704 }
705 (Some(E::DateTimeOffset { scale }), SqlValue::DateTimeOffset(dto)) => {
706 return normalize_ae_datetimeoffset(*dto, scale);
707 }
708 (Some(E::DateTime), SqlValue::DateTime(dt)) => {
709 let mut buf = bytes::BytesMut::with_capacity(8);
710 mssql_types::encode::encode_datetime_legacy(*dt, &mut buf);
711 return Ok(buf.to_vec());
712 }
713 _ => {}
714 }
715 }
716 match value {
717 SqlValue::Bool(v) => Ok(i64::from(*v).to_le_bytes().to_vec()),
721 SqlValue::TinyInt(v) => Ok(i64::from(*v).to_le_bytes().to_vec()),
722 SqlValue::SmallInt(v) => Ok(i64::from(*v).to_le_bytes().to_vec()),
723 SqlValue::Int(v) => Ok(i64::from(*v).to_le_bytes().to_vec()),
724 SqlValue::BigInt(v) => Ok(v.to_le_bytes().to_vec()),
725 SqlValue::Float(v) => Ok(v.to_le_bytes().to_vec()),
727 SqlValue::Double(v) => Ok(v.to_le_bytes().to_vec()),
728 SqlValue::String(s) => Ok(s.encode_utf16().flat_map(u16::to_le_bytes).collect()),
730 SqlValue::Binary(b) => Ok(b.to_vec()),
732 #[cfg(feature = "uuid")]
735 SqlValue::Uuid(u) => {
736 let b = u.as_bytes();
737 Ok(vec![
738 b[3], b[2], b[1], b[0], b[5], b[4], b[7], b[6], b[8], b[9], b[10], b[11], b[12],
739 b[13], b[14], b[15],
740 ])
741 }
742 #[cfg(feature = "chrono")]
745 SqlValue::Date(d) => {
746 use chrono::Datelike;
747 let days = (d.num_days_from_ce() - 1) as u32;
748 Ok(days.to_le_bytes()[..3].to_vec())
749 }
750 #[cfg(feature = "decimal")]
753 SqlValue::Decimal(d) => {
754 let mut out = Vec::with_capacity(17);
755 out.push(u8::from(!d.is_sign_negative()));
756 out.extend_from_slice(&d.mantissa().unsigned_abs().to_le_bytes());
757 Ok(out)
758 }
759 #[cfg(feature = "decimal")]
762 SqlValue::Money(d) | SqlValue::SmallMoney(d) => {
763 let cents = money_cents(d)?;
764 let mut out = ((cents >> 32) as i32).to_le_bytes().to_vec();
765 out.extend_from_slice(&(cents as u32).to_le_bytes());
766 Ok(out)
767 }
768 #[cfg(feature = "chrono")]
771 SqlValue::SmallDateTime(dt) => {
772 let mut buf = bytes::BytesMut::with_capacity(4);
773 mssql_types::encode::encode_smalldatetime(*dt, &mut buf).map_err(|e| {
774 EncryptionError::UnsupportedOperation(format!("SMALLDATETIME: {e}"))
775 })?;
776 Ok(buf.to_vec())
777 }
778 other => Err(EncryptionError::UnsupportedOperation(format!(
779 "Always Encrypted parameter encryption is not yet implemented for {}",
780 other.type_name()
781 ))),
782 }
783}
784
785#[cfg(all(feature = "always-encrypted", feature = "chrono"))]
788fn ae_date_bytes(d: chrono::NaiveDate) -> [u8; 3] {
789 use chrono::Datelike;
790 let days = (d.num_days_from_ce() - 1) as u32;
791 let b = days.to_le_bytes();
792 [b[0], b[1], b[2]]
793}
794
795#[cfg(all(feature = "always-encrypted", feature = "chrono"))]
799fn normalize_ae_time(t: chrono::NaiveTime, scale: u8) -> Result<Vec<u8>, EncryptionError> {
800 use chrono::Timelike;
801 if scale > 7 {
802 return Err(EncryptionError::UnsupportedOperation(format!(
803 "time scale {scale} out of range (0–7)"
804 )));
805 }
806 let nanos =
807 u64::from(t.num_seconds_from_midnight()) * 1_000_000_000 + u64::from(t.nanosecond());
808 let divisor = 10u64.pow(9 - u32::from(scale));
809 let ticks = (nanos + divisor / 2) / divisor;
810 let len = match scale {
811 0..=2 => 3,
812 3..=4 => 4,
813 _ => 5,
814 };
815 Ok(ticks.to_le_bytes()[..len].to_vec())
816}
817
818#[cfg(all(feature = "always-encrypted", feature = "chrono"))]
821fn normalize_ae_datetime2(
822 dt: chrono::NaiveDateTime,
823 scale: u8,
824) -> Result<Vec<u8>, EncryptionError> {
825 let mut out = normalize_ae_time(dt.time(), scale)?;
826 out.extend_from_slice(&ae_date_bytes(dt.date()));
827 Ok(out)
828}
829
830#[cfg(all(feature = "always-encrypted", feature = "chrono"))]
833fn normalize_ae_datetimeoffset(
834 dto: chrono::DateTime<chrono::FixedOffset>,
835 scale: u8,
836) -> Result<Vec<u8>, EncryptionError> {
837 use chrono::Offset;
838 let utc = dto.naive_utc();
839 let mut out = normalize_ae_time(utc.time(), scale)?;
840 out.extend_from_slice(&ae_date_bytes(utc.date()));
841 let offset_minutes = (dto.offset().fix().local_minus_utc() / 60) as i16;
842 out.extend_from_slice(&offset_minutes.to_le_bytes());
843 Ok(out)
844}
845
846#[cfg(all(feature = "always-encrypted", feature = "decimal"))]
849fn money_cents(value: &rust_decimal::Decimal) -> Result<i64, EncryptionError> {
850 let mantissa = value.mantissa();
851 let scale = value.scale();
852 let cents: i128 = if scale <= 4 {
853 mantissa
854 .checked_mul(10_i128.pow(4 - scale))
855 .ok_or_else(|| {
856 EncryptionError::UnsupportedOperation("MONEY value out of range".into())
857 })?
858 } else {
859 mantissa / 10_i128.pow(scale - 4)
860 };
861 i64::try_from(cents)
862 .map_err(|_| EncryptionError::UnsupportedOperation("MONEY value out of range".into()))
863}
864
865#[cfg(test)]
866#[allow(clippy::unwrap_used, clippy::expect_used)]
867mod tests {
868 use super::*;
869
870 #[cfg(feature = "always-encrypted")]
876 #[test]
877 fn ae_normalization_matches_dotnet() {
878 use bytes::Bytes;
879
880 fn unhex(s: &str) -> Vec<u8> {
881 (0..s.len())
882 .step_by(2)
883 .map(|i| u8::from_str_radix(&s[i..i + 2], 16).unwrap())
884 .collect()
885 }
886
887 let cek = unhex("B59D9F2C96784C232D53AB273D257DC79B7D2355BB82B1EC7054CE25E25F7B44");
888 let enc = AeadEncryptor::new(&cek).unwrap();
889
890 for (value, reference) in [
891 (
892 SqlValue::Int(42),
893 "01102FC5DEC5D3E463A8F4BDF512AA74E6AB953BA9A2F3F9A98CD18446B007DE5A6E2A1D1EB775035EA189CA5160A935CE093CAA9BB7E9233BB333AADEE86FDE1D",
894 ),
895 (
896 SqlValue::String("Ada".to_string()),
897 "01BFAC40E6DA541ACEFAD8ECF5598DB77B0C5349CFACBC3C9221C01B6037E593B78E8F398F620F837BD6A4A2B644125C4188DF278B94479B2218466D91107FE417",
898 ),
899 (
900 SqlValue::Binary(Bytes::from_static(&[0x01, 0x02, 0x03])),
901 "01ADE71457495F00FC9A16456F1B1EECB901D88DE97887025C189B1C4432E02071AB7594C48518CA5621E90165FAE337475B4CF3A3D00EF2D862FB0473713DF1E1",
902 ),
903 ] {
904 let norm = normalize_for_encryption(&value, None).unwrap();
905 let cipher = enc
906 .encrypt(&norm, mssql_auth::EncryptionType::Deterministic)
907 .unwrap();
908 assert_eq!(
909 cipher,
910 unhex(reference),
911 "ciphertext for {} must match Microsoft.Data.SqlClient",
912 value.type_name()
913 );
914 }
915 }
916
917 #[cfg(feature = "always-encrypted")]
922 #[test]
923 fn ae_normalization_rejects_unnormalizable_value() {
924 assert!(normalize_for_encryption(&SqlValue::Null, None).is_err());
925 }
926
927 #[cfg(feature = "always-encrypted")]
933 #[test]
934 fn ae_normalization_matches_dotnet_numeric() {
935 fn unhex(s: &str) -> Vec<u8> {
936 (0..s.len())
937 .step_by(2)
938 .map(|i| u8::from_str_radix(&s[i..i + 2], 16).unwrap())
939 .collect()
940 }
941
942 let cek = unhex("9590E42A8A6C8F13B5D09B8D5A128EF8B3A4A10301C7AF24AFC62ED0E02342F7");
943 let enc = AeadEncryptor::new(&cek).unwrap();
944
945 for (value, reference) in [
946 (
947 SqlValue::BigInt(0x0102030405060708),
948 "01E765FC4696660028BFD48FCAEAED81E0EB423CFF433CA97F1B2FF02F70744E7265C2AE73CAA562FFA98AF98CB1D3EF6A4649B3640359E1DB7D170C80E639DA68",
949 ),
950 (
951 SqlValue::SmallInt(258),
952 "012545AB817E1AEBDCEE1C00AEBFF3A013CAD20E0377BEFDD9186C263F8D1A909C313A753996F1B5E4A4AE17E901F6F781DCA707544766995D339601CA414063A0",
953 ),
954 (
955 SqlValue::TinyInt(200),
956 "01A97C33480277D16FFAEDA9068173D4173378542F2887EBCD31CDEEEB116BD59D48F9D459BDDCABAE469E891B4F82AA3D283440CA1B5E9FFC150F9D0AE54EC21E",
957 ),
958 (
959 SqlValue::Bool(true),
960 "01DDE18564051D630EE026331BCCAFC8F4122CC3919F81459F37D9C0E0C64A5317FCA08660FE5FC855917B97B72013F25B85ADD14ADDD7D5ED022EB1297FF29A7E",
961 ),
962 (
963 SqlValue::Float(3.5),
964 "017A452760E7BA7AA6A716F6707F55D9C3A81683C04A6B561B13AC1D8A848E93E239BB922EE3EE628B6D0081A590BB11747CC25D216240FB10171A0FA3B99A2DB3",
965 ),
966 (
967 SqlValue::Double(3.5),
968 "0171611557351FBC4561EBF0B9C98E0DC38AD2BD3E2C1D1E82F185D7E67D0425E506D11DD67BA3EB38F34FB01A8FCEF7E4B9A7256944334A521526613CFF6C8C5F",
969 ),
970 ] {
971 let norm = normalize_for_encryption(&value, None).unwrap();
972 let cipher = enc
973 .encrypt(&norm, mssql_auth::EncryptionType::Deterministic)
974 .unwrap();
975 assert_eq!(
976 cipher,
977 unhex(reference),
978 "ciphertext for {} must match Microsoft.Data.SqlClient",
979 value.type_name()
980 );
981 }
982 }
983
984 #[cfg(all(feature = "always-encrypted", feature = "uuid", feature = "chrono"))]
988 #[test]
989 fn ae_normalization_matches_dotnet_uuid_date() {
990 fn unhex(s: &str) -> Vec<u8> {
991 (0..s.len())
992 .step_by(2)
993 .map(|i| u8::from_str_radix(&s[i..i + 2], 16).unwrap())
994 .collect()
995 }
996
997 let cek = unhex("9590E42A8A6C8F13B5D09B8D5A128EF8B3A4A10301C7AF24AFC62ED0E02342F7");
998 let enc = AeadEncryptor::new(&cek).unwrap();
999
1000 for (value, reference) in [
1001 (
1002 SqlValue::Uuid(
1003 uuid::Uuid::parse_str("01020304-0506-0708-090a-0b0c0d0e0f10").unwrap(),
1004 ),
1005 "01F58635AA18692D68BDF551ECDD7AC3A56682D3F91F111F8D8F36D5425C405A8F6AB3ED3C3666444478476BD65FF40DC83F6831F502826AFEEC3116F71A7A2020CCD254F4BA28FCDC0F96BA2E5264AE9E",
1006 ),
1007 (
1008 SqlValue::Date(chrono::NaiveDate::from_ymd_opt(2024, 3, 15).unwrap()),
1009 "0188B4F75A1F4BDA53C9CDDC1918C09CB57F68E13F5560F1F1D7168FE70707337B1156A97915B244F3C03D3E7352882A599511BD243471FD03683F371CF44E4B76",
1010 ),
1011 ] {
1012 let norm = normalize_for_encryption(&value, None).unwrap();
1013 let cipher = enc
1014 .encrypt(&norm, mssql_auth::EncryptionType::Deterministic)
1015 .unwrap();
1016 assert_eq!(
1017 cipher,
1018 unhex(reference),
1019 "ciphertext for {} must match Microsoft.Data.SqlClient",
1020 value.type_name()
1021 );
1022 }
1023 }
1024
1025 #[cfg(all(feature = "always-encrypted", feature = "decimal"))]
1030 #[test]
1031 fn ae_normalization_matches_dotnet_decimal_money() {
1032 fn unhex(s: &str) -> Vec<u8> {
1033 (0..s.len())
1034 .step_by(2)
1035 .map(|i| u8::from_str_radix(&s[i..i + 2], 16).unwrap())
1036 .collect()
1037 }
1038
1039 let cek = unhex("CBFB5AE21FB517C65DA0C6E8E11969C630798E473EF5827A70398012DF1D4B9E");
1040 let enc = AeadEncryptor::new(&cek).unwrap();
1041 let dec = rust_decimal::Decimal::new(123_456_789, 4); let money = rust_decimal::Decimal::new(123_400, 4); for (value, reference) in [
1045 (
1046 SqlValue::Decimal(dec),
1047 "018FAE46024B9B406C23600E6A9C694F9A9B39B785A995689EBE19437BA7E75768011A035A5B54B5E495512EBB46AE1146130940A0D0D834D61AA89B5AD9F71FFAF6EEEAE77E4856BA2AA5E016E2950A8D",
1048 ),
1049 (
1050 SqlValue::Money(money),
1051 "01B4CE4CAD8D6B241A1555C377A0ADD4C79424DD5162F710D116594F725C1BAB015169A0C7716076EEC90E013519B961DEF427BFC32462D9E45D166C791B73F793",
1052 ),
1053 (
1054 SqlValue::SmallMoney(money),
1055 "01B4CE4CAD8D6B241A1555C377A0ADD4C79424DD5162F710D116594F725C1BAB015169A0C7716076EEC90E013519B961DEF427BFC32462D9E45D166C791B73F793",
1056 ),
1057 ] {
1058 let norm = normalize_for_encryption(&value, None).unwrap();
1059 let cipher = enc
1060 .encrypt(&norm, mssql_auth::EncryptionType::Deterministic)
1061 .unwrap();
1062 assert_eq!(
1063 cipher,
1064 unhex(reference),
1065 "ciphertext for {} must match Microsoft.Data.SqlClient",
1066 value.type_name()
1067 );
1068 }
1069 }
1070
1071 #[cfg(all(feature = "always-encrypted", feature = "chrono"))]
1077 #[test]
1078 fn ae_normalization_matches_dotnet_temporal() {
1079 use mssql_types::EncryptedParamType as E;
1080 fn unhex(s: &str) -> Vec<u8> {
1081 (0..s.len())
1082 .step_by(2)
1083 .map(|i| u8::from_str_radix(&s[i..i + 2], 16).unwrap())
1084 .collect()
1085 }
1086
1087 let day = chrono::NaiveDate::from_ymd_opt(2024, 3, 15).unwrap();
1088 let dt = day.and_hms_nano_opt(13, 14, 15, 123_456_700).unwrap();
1089
1090 assert_eq!(
1092 normalize_for_encryption(&SqlValue::Time(dt.time()), Some(E::Time { scale: 7 }))
1093 .unwrap(),
1094 unhex("07c4aaf46e"),
1095 );
1096 assert_eq!(
1098 normalize_for_encryption(&SqlValue::DateTime(dt), Some(E::DateTime2 { scale: 7 }))
1099 .unwrap(),
1100 unhex("07c4aaf46e8f460b"),
1101 );
1102 let dto = {
1104 use chrono::TimeZone;
1105 chrono::FixedOffset::east_opt(5 * 3600 + 30 * 60)
1106 .unwrap()
1107 .from_local_datetime(&dt)
1108 .single()
1109 .unwrap()
1110 };
1111 assert_eq!(
1112 normalize_for_encryption(
1113 &SqlValue::DateTimeOffset(dto),
1114 Some(E::DateTimeOffset { scale: 7 })
1115 )
1116 .unwrap(),
1117 unhex("0788f2da408f460b4a01"),
1118 );
1119 let dt_legacy = day.and_hms_milli_opt(13, 14, 15, 123).unwrap();
1121 assert_eq!(
1122 normalize_for_encryption(&SqlValue::DateTime(dt_legacy), Some(E::DateTime)).unwrap(),
1123 unhex("34b10000d925da00"),
1124 );
1125 let sdt = day.and_hms_opt(13, 14, 0).unwrap();
1127 assert_eq!(
1128 normalize_for_encryption(&SqlValue::SmallDateTime(sdt), None).unwrap(),
1129 unhex("34b11a03"),
1130 );
1131 }
1132
1133 #[cfg(feature = "always-encrypted")]
1138 #[test]
1139 fn ae_normalization_matches_dotnet_fixed_width() {
1140 use mssql_types::EncryptedParamType as E;
1141 fn unhex(s: &str) -> Vec<u8> {
1142 (0..s.len())
1143 .step_by(2)
1144 .map(|i| u8::from_str_radix(&s[i..i + 2], 16).unwrap())
1145 .collect()
1146 }
1147 assert_eq!(
1149 normalize_for_encryption(
1150 &SqlValue::String("Hello".to_string()),
1151 Some(E::Char { length: 10 })
1152 )
1153 .unwrap(),
1154 unhex("48656c6c6f"),
1155 );
1156 assert_eq!(
1158 normalize_for_encryption(
1159 &SqlValue::String("Hello".to_string()),
1160 Some(E::NChar { length: 10 })
1161 )
1162 .unwrap(),
1163 unhex("480065006c006c006f00"),
1164 );
1165 assert_eq!(
1167 normalize_for_encryption(
1168 &SqlValue::Binary(bytes::Bytes::from_static(&[1, 2, 3, 4, 5])),
1169 Some(E::Binary { length: 10 })
1170 )
1171 .unwrap(),
1172 unhex("0102030405"),
1173 );
1174 }
1175
1176 #[test]
1177 fn test_encryption_config_defaults() {
1178 let config = EncryptionConfig::new();
1179 assert!(config.enabled);
1180 assert!(config.cache_ceks);
1181 assert!(!config.is_ready()); }
1183
1184 #[test]
1185 fn test_result_set_encryption_info() {
1186 let cek_table = CekTable::new();
1187 let mut info = ResultSetEncryptionInfo::new(cek_table, 3);
1188
1189 assert!(!info.is_column_encrypted(0));
1190 assert!(!info.is_column_encrypted(1));
1191 assert!(!info.is_column_encrypted(2));
1192
1193 let metadata = CryptoMetadata {
1194 cek_table_ordinal: 0,
1195 base_user_type: 0,
1196 base_col_type: 0x26,
1197 base_type_info: tds_protocol::token::TypeInfo::default(),
1198 algorithm_id: 2,
1199 encryption_type: EncryptionTypeWire::Deterministic,
1200 normalization_version: 1,
1201 };
1202
1203 info.set_column_crypto(1, metadata);
1204 assert!(!info.is_column_encrypted(0));
1205 assert!(info.is_column_encrypted(1));
1206 assert!(!info.is_column_encrypted(2));
1207
1208 assert_eq!(
1209 info.get_encryption_type(1),
1210 Some(EncryptionTypeWire::Deterministic)
1211 );
1212 }
1213
1214 #[test]
1215 fn test_parameter_encryption_info() {
1216 let mut info = ParameterEncryptionInfo::new();
1217
1218 assert!(!info.needs_encryption("@p1"));
1219
1220 let crypto = ParameterCryptoInfo::new(0, EncryptionTypeWire::Randomized, 2, 1);
1221 info.add_parameter("@p1".to_string(), crypto);
1222
1223 assert!(info.needs_encryption("@p1"));
1224 assert!(!info.needs_encryption("@p2"));
1225
1226 let param = info.get_parameter("@p1").unwrap();
1227 assert_eq!(param.encryption_type, EncryptionTypeWire::Randomized);
1228 }
1229
1230 #[cfg(feature = "always-encrypted")]
1238 #[test]
1239 fn parse_describe_result_sets_groups_ceks_and_skips_plaintext() {
1240 use crate::row::{Column, Row};
1241 use crate::stream::ResultSet;
1242 use bytes::Bytes;
1243
1244 fn rs(n_cols: usize, rows: Vec<Vec<SqlValue>>) -> ResultSet {
1245 let cols: Vec<Column> = (0..n_cols)
1246 .map(|i| Column::new(format!("c{i}"), i, "x"))
1247 .collect();
1248 let rows = rows
1249 .into_iter()
1250 .map(|vals| Row::from_values(cols.clone(), vals))
1251 .collect();
1252 ResultSet::new(cols, rows)
1253 }
1254
1255 let mdv1 = Bytes::from_static(&[1, 0, 0, 0, 0, 0, 0, 0]); let mdv2 = Bytes::from_static(&[255, 0, 0, 0, 0, 0, 0, 0]); let rs1 = rs(
1260 9,
1261 vec![
1262 vec![
1263 SqlValue::Int(1),
1264 SqlValue::Int(7),
1265 SqlValue::Int(56),
1266 SqlValue::Int(1),
1267 SqlValue::Binary(mdv1.clone()),
1268 SqlValue::Binary(Bytes::from_static(b"env-a")),
1269 SqlValue::String("IN_MEMORY_KEY_STORE".into()),
1270 SqlValue::String("path-a".into()),
1271 SqlValue::String("RSA_OAEP".into()),
1272 ],
1273 vec![
1274 SqlValue::Int(1),
1275 SqlValue::Int(7),
1276 SqlValue::Int(56),
1277 SqlValue::Int(1),
1278 SqlValue::Binary(mdv1),
1279 SqlValue::Binary(Bytes::from_static(b"env-a2")),
1280 SqlValue::String("PROV_2".into()),
1281 SqlValue::String("path-a2".into()),
1282 SqlValue::String("RSA_OAEP".into()),
1283 ],
1284 vec![
1285 SqlValue::Int(2),
1286 SqlValue::Int(7),
1287 SqlValue::Int(57),
1288 SqlValue::Int(1),
1289 SqlValue::Binary(mdv2),
1290 SqlValue::Binary(Bytes::from_static(b"env-b")),
1291 SqlValue::String("IN_MEMORY_KEY_STORE".into()),
1292 SqlValue::String("path-b".into()),
1293 SqlValue::String("RSA_OAEP".into()),
1294 ],
1295 ],
1296 );
1297
1298 let rs2 = rs(
1300 6,
1301 vec![
1302 vec![
1303 SqlValue::Int(1),
1304 SqlValue::String("@det".into()),
1305 SqlValue::TinyInt(2),
1306 SqlValue::TinyInt(1),
1307 SqlValue::Int(1),
1308 SqlValue::TinyInt(1),
1309 ],
1310 vec![
1311 SqlValue::Int(2),
1312 SqlValue::String("@rand".into()),
1313 SqlValue::TinyInt(2),
1314 SqlValue::TinyInt(2),
1315 SqlValue::Int(2),
1316 SqlValue::TinyInt(1),
1317 ],
1318 vec![
1319 SqlValue::Int(3),
1320 SqlValue::String("@plain".into()),
1321 SqlValue::TinyInt(0),
1322 SqlValue::TinyInt(0),
1323 SqlValue::Int(0),
1324 SqlValue::TinyInt(0),
1325 ],
1326 ],
1327 );
1328
1329 let mut sets = vec![rs1, rs2];
1330 let info = ParameterEncryptionInfo::from_describe_result_sets(&mut sets).unwrap();
1331
1332 assert_eq!(info.cek_table.len(), 2);
1333 let e0 = info.cek_table.get(0).unwrap();
1334 assert_eq!(e0.cek_id, 56);
1335 assert_eq!(e0.cek_md_version, 1);
1336 assert_eq!(e0.values.len(), 2, "two CMK-wrappings group under one CEK");
1337 assert_eq!(e0.values[0].key_store_provider_name, "IN_MEMORY_KEY_STORE");
1338 assert_eq!(e0.values[1].key_store_provider_name, "PROV_2");
1339 let e1 = info.cek_table.get(1).unwrap();
1340 assert_eq!(e1.cek_id, 57);
1341 assert_eq!(e1.cek_md_version, 255);
1342
1343 let det = info.get_parameter("@det").unwrap();
1344 assert_eq!(det.encryption_type, EncryptionTypeWire::Deterministic);
1345 assert_eq!(det.algorithm_id, 2);
1346 assert_eq!(det.normalization_rule_version, 1);
1347 assert_eq!(det.cek_ordinal, 0, "server ordinal 1 -> positional index 0");
1348
1349 let rand = info.get_parameter("@rand").unwrap();
1350 assert_eq!(rand.encryption_type, EncryptionTypeWire::Randomized);
1351 assert_eq!(
1352 rand.cek_ordinal, 1,
1353 "server ordinal 2 -> positional index 1"
1354 );
1355
1356 assert!(!info.needs_encryption("@plain"));
1357 assert_eq!(info.parameters.len(), 2);
1358 }
1359
1360 #[cfg(feature = "always-encrypted")]
1363 #[test]
1364 fn parse_describe_result_sets_rejects_missing_result_set() {
1365 use crate::row::{Column, Row};
1366 use crate::stream::ResultSet;
1367
1368 let cols: Vec<Column> = (0..9)
1369 .map(|i| Column::new(format!("c{i}"), i, "x"))
1370 .collect();
1371 let mut sets = vec![ResultSet::new(cols, Vec::<Row>::new())];
1372 assert!(ParameterEncryptionInfo::from_describe_result_sets(&mut sets).is_err());
1373 }
1374}