#[cfg(feature = "always-encrypted")]
use std::collections::HashMap;
use mssql_auth::KeyStoreProvider;
#[cfg(feature = "always-encrypted")]
use tds_protocol::crypto::{CekTable, CekTableEntry, EncryptionTypeWire};
#[cfg(feature = "always-encrypted")]
use mssql_auth::{AeadEncryptor, CekCache, CekCacheKey, EncryptionError};
#[cfg(feature = "always-encrypted")]
use mssql_types::SqlValue;
#[cfg(feature = "always-encrypted")]
use std::sync::Arc;
#[cfg(feature = "always-encrypted")]
use crate::{Error, row::Row, stream::ResultSet};
#[cfg(feature = "always-encrypted")]
use tds_protocol::crypto::CekValue;
/// Client-side Always Encrypted settings: an on/off switch, the key store
/// providers able to unwrap column encryption keys (CEKs), and whether
/// decrypted CEKs may be cached.
///
/// Note: the derived [`Default`] yields `enabled == false` and
/// `cache_ceks == false`, whereas [`EncryptionConfig::new`] sets both to `true`.
#[derive(Default)]
pub struct EncryptionConfig {
    /// Whether Always Encrypted is switched on (checked by `is_ready`).
    pub enabled: bool,
    /// Whether decrypted CEKs may be cached by the encryption context.
    pub cache_ceks: bool,
    /// Registered key store providers, searched in registration order.
    providers: Vec<Box<dyn KeyStoreProvider>>,
}
impl EncryptionConfig {
    /// Creates a configuration with encryption and CEK caching both enabled
    /// and no key store providers registered yet.
    #[must_use]
    pub fn new() -> Self {
        let providers = Vec::new();
        Self {
            providers,
            enabled: true,
            cache_ceks: true,
        }
    }

    /// Appends `provider` to the registry.
    ///
    /// Lookups return the *first* provider with a matching name, so
    /// registering a second provider under an existing name does not replace
    /// the earlier one.
    pub fn register_provider(&mut self, provider: impl KeyStoreProvider + 'static) {
        let boxed: Box<dyn KeyStoreProvider> = Box::new(provider);
        self.providers.push(boxed);
    }

    /// Builder-style form of [`register_provider`](Self::register_provider).
    #[must_use]
    pub fn with_provider(mut self, provider: impl KeyStoreProvider + 'static) -> Self {
        self.register_provider(provider);
        self
    }

    /// Builder-style setter for whether decrypted CEKs may be cached.
    #[must_use]
    pub fn with_cek_caching(mut self, enabled: bool) -> Self {
        self.cache_ceks = enabled;
        self
    }

    /// Returns the first registered provider whose `provider_name()` equals `name`.
    pub fn get_provider(&self, name: &str) -> Option<&dyn KeyStoreProvider> {
        for provider in &self.providers {
            if provider.provider_name() == name {
                return Some(provider.as_ref());
            }
        }
        None
    }

    /// Returns `true` when encryption is enabled and at least one provider is registered.
    #[must_use]
    pub fn is_ready(&self) -> bool {
        if !self.enabled {
            return false;
        }
        !self.providers.is_empty()
    }
}
impl std::fmt::Debug for EncryptionConfig {
    // Providers are trait objects without a Debug bound, so only their count
    // is reported.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let provider_count = self.providers.len();
        let mut out = f.debug_struct("EncryptionConfig");
        out.field("enabled", &self.enabled);
        out.field("provider_count", &provider_count);
        out.field("cache_ceks", &self.cache_ceks);
        out.finish()
    }
}
/// Runtime Always Encrypted state: the shared configuration (and its key
/// store providers) plus a cache of encryptors keyed by
/// `(database_id, cek_id, cek_version)`.
#[cfg(feature = "always-encrypted")]
pub(crate) struct EncryptionContext {
    /// Shared configuration holding the registered key store providers.
    config: Arc<EncryptionConfig>,
    /// Encryptors built from previously decrypted CEKs.
    cek_cache: CekCache,
    /// Copy of `config.cache_ceks`, taken when the context was built.
    cache_enabled: bool,
}
#[cfg(feature = "always-encrypted")]
impl EncryptionContext {
    /// Creates a context that shares `config` and starts with an empty CEK cache.
    ///
    /// Whether the cache is used is decided once, from `config.cache_ceks`.
    pub fn from_arc(config: std::sync::Arc<EncryptionConfig>) -> Self {
        let cache_enabled = config.cache_ceks;
        Self {
            config,
            cek_cache: CekCache::new(),
            cache_enabled,
        }
    }

    /// Returns an encryptor for the column encryption key (CEK) in `cek_entry`.
    ///
    /// When caching is enabled, an encryptor previously built for the same
    /// `(database_id, cek_id, cek_version)` is returned without contacting a
    /// key store. Otherwise the CEK must be unwrapped by a registered key
    /// store provider.
    ///
    /// A CEK may be wrapped by several column master keys (one `CekValue`
    /// per wrapping). The primary wrapping is tried first, then the remaining
    /// ones in table order:
    /// - wrappings whose provider is not registered are skipped;
    /// - a provider that fails to decrypt does not abort the lookup.
    ///
    /// This lets a client that only holds the provider for a secondary CMK
    /// (e.g. during CMK rotation) still obtain the key.
    ///
    /// # Errors
    ///
    /// - `CekDecryptionFailed` if the entry carries no CEK value.
    /// - The error of the last provider tried, if every registered provider
    ///   failed to decrypt its wrapping.
    /// - `KeyStoreNotFound` naming the primary wrapping's provider, if no
    ///   wrapping uses a registered provider.
    /// - Any error from building the encryptor or inserting it into the cache.
    pub(crate) async fn get_encryptor(
        &self,
        cek_entry: &CekTableEntry,
    ) -> Result<Arc<AeadEncryptor>, EncryptionError> {
        let cache_key = CekCacheKey::new(
            cek_entry.database_id,
            cek_entry.cek_id,
            cek_entry.cek_version,
        );
        if self.cache_enabled {
            if let Some(encryptor) = self.cek_cache.get(&cache_key) {
                return Ok(encryptor);
            }
        }

        let primary = cek_entry
            .primary_value()
            .ok_or_else(|| EncryptionError::CekDecryptionFailed("No CEK value available".into()))?;
        // Primary wrapping first, then any alternates in table order;
        // `ptr::eq` keeps the primary from being tried twice.
        let candidates = std::iter::once(primary).chain(
            cek_entry
                .values
                .iter()
                .filter(move |candidate| !std::ptr::eq(*candidate, primary)),
        );

        let mut decrypted = None;
        let mut last_error: Option<EncryptionError> = None;
        for cek_value in candidates {
            // A wrapping whose provider is not registered cannot be unwrapped here.
            let provider = match self.config.get_provider(&cek_value.key_store_provider_name) {
                Some(provider) => provider,
                None => continue,
            };
            match provider
                .decrypt_cek(
                    &cek_value.cmk_path,
                    &cek_value.encryption_algorithm,
                    &cek_value.encrypted_value,
                )
                .await
            {
                Ok(cek) => {
                    decrypted = Some(cek);
                    break;
                }
                // Remember the failure and fall through to the next wrapping.
                Err(err) => last_error = Some(err.into()),
            }
        }

        let decrypted_cek = match decrypted {
            Some(cek) => cek,
            None => {
                return Err(last_error.unwrap_or_else(|| {
                    EncryptionError::KeyStoreNotFound(primary.key_store_provider_name.clone())
                }));
            }
        };

        if self.cache_enabled {
            self.cek_cache.insert(cache_key, decrypted_cek)
        } else {
            Ok(Arc::new(AeadEncryptor::new(&decrypted_cek)?))
        }
    }

    /// Encrypts `plaintext` with the CEK described by `cek_entry`.
    ///
    /// # Errors
    ///
    /// - `UnsupportedOperation` for any `encryption_type` other than
    ///   `Deterministic` or `Randomized`.
    /// - Any error from [`get_encryptor`](Self::get_encryptor) or from the
    ///   encryptor itself.
    pub(crate) async fn encrypt_value(
        &self,
        plaintext: &[u8],
        cek_entry: &CekTableEntry,
        encryption_type: EncryptionTypeWire,
    ) -> Result<Vec<u8>, EncryptionError> {
        // Validate the encryption type before unwrapping the CEK, so an
        // unsupported type fails without a key store provider call.
        let enc_type = match encryption_type {
            EncryptionTypeWire::Deterministic => mssql_auth::EncryptionType::Deterministic,
            EncryptionTypeWire::Randomized => mssql_auth::EncryptionType::Randomized,
            _ => {
                return Err(EncryptionError::UnsupportedOperation(format!(
                    "unsupported encryption type: {encryption_type:?}"
                )));
            }
        };
        let encryptor = self.get_encryptor(cek_entry).await?;
        encryptor.encrypt(plaintext, enc_type)
    }

    /// Returns `true` if a provider named `name` is registered in the configuration.
    pub fn has_provider(&self, name: &str) -> bool {
        self.config.get_provider(name).is_some()
    }
}
#[cfg(feature = "always-encrypted")]
impl std::fmt::Debug for EncryptionContext {
    // Only counts and the cache flag are shown; no key material is printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let provider_count = self.config.providers.len();
        let cache_entries = self.cek_cache.len();
        let mut out = f.debug_struct("EncryptionContext");
        out.field("provider_count", &provider_count);
        out.field("cache_entries", &cache_entries);
        out.field("cache_enabled", &self.cache_enabled);
        out.finish()
    }
}
/// Encryption metadata for a parameterized statement, as parsed from the
/// result sets of `sp_describe_parameter_encryption`.
#[cfg(feature = "always-encrypted")]
#[derive(Debug, Clone)]
pub(crate) struct ParameterEncryptionInfo {
/// CEKs referenced by the statement's encrypted parameters, in positional order.
pub cek_table: CekTable,
/// Crypto metadata keyed by parameter name; plaintext parameters are absent.
pub parameters: HashMap<String, ParameterCryptoInfo>,
}
#[cfg(feature = "always-encrypted")]
impl ParameterEncryptionInfo {
    /// Creates an empty description: no CEKs and no encrypted parameters.
    pub fn new() -> Self {
        let cek_table = CekTable::new();
        let parameters = HashMap::new();
        Self {
            cek_table,
            parameters,
        }
    }

    /// Returns the crypto metadata for parameter `name`, or `None` if the
    /// parameter is unknown or is sent in plaintext.
    pub fn get_parameter(&self, name: &str) -> Option<&ParameterCryptoInfo> {
        let params = &self.parameters;
        params.get(name)
    }
}
#[cfg(feature = "always-encrypted")]
impl Default for ParameterEncryptionInfo {
    /// Equivalent to [`ParameterEncryptionInfo::new`]: an empty description.
    fn default() -> Self {
        ParameterEncryptionInfo::new()
    }
}
/// Per-parameter Always Encrypted metadata from result set 2 of
/// `sp_describe_parameter_encryption`.
#[cfg(feature = "always-encrypted")]
#[derive(Debug, Clone)]
pub(crate) struct ParameterCryptoInfo {
/// Positional index into `ParameterEncryptionInfo::cek_table` (already mapped
/// from the server's key ordinal, not the raw ordinal itself).
pub cek_ordinal: u16,
/// Deterministic or randomized encryption.
pub encryption_type: EncryptionTypeWire,
/// Raw `column_encryption_algorithm` value reported by the server.
pub algorithm_id: u8,
/// Raw `column_encryption_normalization_rule_version` reported by the server.
pub normalization_rule_version: u8,
}
#[cfg(feature = "always-encrypted")]
impl ParameterEncryptionInfo {
    /// Minimum column count of result set 1 (CEK descriptions); columns 0–8 are read.
    const RS1_MIN_COLS: usize = 9;
    /// Minimum column count of result set 2 (parameter descriptions); columns 1–5 are read.
    const RS2_MIN_COLS: usize = 6;

    /// Builds parameter encryption metadata from the result sets returned by
    /// `sp_describe_parameter_encryption`.
    ///
    /// Result set 1 has one row per (CEK, CMK wrapping) pair. Columns read:
    /// - 0: key ordinal
    /// - 1: `database_id`
    /// - 2: `column_encryption_key_id`
    /// - 3: `column_encryption_key_version`
    /// - 4: metadata version (`binary(8)`)
    /// - 5: encrypted CEK
    /// - 6: key store provider name
    /// - 7: CMK path
    /// - 8: CEK encryption algorithm name
    ///
    /// Rows sharing a key ordinal are grouped into one `CekTableEntry`, one
    /// value per wrapping. Entries are numbered positionally in first-seen order.
    ///
    /// Result set 2 has one row per parameter. Columns read:
    /// - 1: parameter name
    /// - 2: algorithm id
    /// - 3: encryption type
    /// - 4: key ordinal (as in result set 1)
    /// - 5: normalization rule version
    ///
    /// Rows with encryption type 0 (plaintext) are skipped.
    ///
    /// # Errors
    ///
    /// `Error::Protocol` in any of these cases:
    /// - fewer than two result sets;
    /// - too few columns in a result set;
    /// - a column of an unexpected type;
    /// - a negative id or version;
    /// - an invalid encryption type;
    /// - a parameter referencing a key ordinal missing from result set 1;
    /// - more CEKs than fit a `u16` index.
    ///
    /// Errors from collecting the rows are propagated.
    pub(crate) fn from_describe_result_sets(result_sets: &mut [ResultSet]) -> Result<Self, Error> {
        if result_sets.len() < 2 {
            return Err(Error::Protocol(format!(
                "sp_describe_parameter_encryption returned {} result set(s), expected 2",
                result_sets.len()
            )));
        }

        // Result set 1: the CEK table.
        let rs1_cols = result_sets[0].columns().len();
        if rs1_cols < Self::RS1_MIN_COLS {
            return Err(Error::Protocol(format!(
                "sp_describe_parameter_encryption result set 1 has {rs1_cols} columns, expected >= {}",
                Self::RS1_MIN_COLS
            )));
        }
        let rs1_rows = result_sets[0].collect_all()?;
        let mut entries: Vec<CekTableEntry> = Vec::new();
        // Server key ordinal -> positional index in `entries`.
        let mut ordinal_to_index: HashMap<i32, u16> = HashMap::new();
        for row in &rs1_rows {
            let key_ordinal = describe_int(row, 0, "column_encryption_key_ordinal")?;
            let value = CekValue {
                encrypted_value: describe_varbinary(
                    row,
                    5,
                    "column_encryption_key_encrypted_value",
                )?,
                key_store_provider_name: describe_nvarchar(
                    row,
                    6,
                    "column_master_key_store_provider_name",
                )?,
                cmk_path: describe_nvarchar(row, 7, "column_master_key_path")?,
                encryption_algorithm: describe_nvarchar(
                    row,
                    8,
                    "column_encryption_key_encryption_algorithm_name",
                )?,
            };
            if let Some(&idx) = ordinal_to_index.get(&key_ordinal) {
                // Another CMK wrapping of a CEK we have already seen.
                entries[idx as usize].values.push(value);
            } else {
                let idx = u16::try_from(entries.len()).map_err(|_| {
                    Error::Protocol(
                        "sp_describe_parameter_encryption returned too many CEKs".into(),
                    )
                })?;
                ordinal_to_index.insert(key_ordinal, idx);
                entries.push(CekTableEntry {
                    database_id: Self::describe_u32(row, 1, "database_id")?,
                    cek_id: Self::describe_u32(row, 2, "column_encryption_key_id")?,
                    cek_version: Self::describe_u32(row, 3, "column_encryption_key_version")?,
                    cek_md_version: describe_md_version(row, 4)?,
                    values: vec![value],
                });
            }
        }
        let cek_table = CekTable { entries };

        // Result set 2: per-parameter metadata.
        let rs2_cols = result_sets[1].columns().len();
        if rs2_cols < Self::RS2_MIN_COLS {
            return Err(Error::Protocol(format!(
                "sp_describe_parameter_encryption result set 2 has {rs2_cols} columns, expected >= {}",
                Self::RS2_MIN_COLS
            )));
        }
        let rs2_rows = result_sets[1].collect_all()?;
        let mut parameters = HashMap::new();
        for row in &rs2_rows {
            let name = describe_nvarchar(row, 1, "parameter_name")?;
            let encryption_type_byte = describe_tinyint(row, 3, "column_encryption_type")?;
            if encryption_type_byte == 0 {
                // Plaintext parameter: nothing to encrypt.
                continue;
            }
            let encryption_type =
                EncryptionTypeWire::from_u8(encryption_type_byte).ok_or_else(|| {
                    Error::Protocol(format!(
                        "sp_describe_parameter_encryption: invalid column_encryption_type {encryption_type_byte} for {name}"
                    ))
                })?;
            let algorithm_id = describe_tinyint(row, 2, "column_encryption_algorithm")?;
            let server_ordinal = describe_int(row, 4, "column_encryption_key_ordinal")?;
            let normalization_rule_version =
                describe_tinyint(row, 5, "column_encryption_normalization_rule_version")?;
            let cek_ordinal = *ordinal_to_index.get(&server_ordinal).ok_or_else(|| {
                Error::Protocol(format!(
                    "sp_describe_parameter_encryption: parameter {name} references CEK ordinal {server_ordinal} absent from the CEK table"
                ))
            })?;
            parameters.insert(
                name,
                ParameterCryptoInfo {
                    cek_ordinal,
                    encryption_type,
                    algorithm_id,
                    normalization_rule_version,
                },
            );
        }
        Ok(Self {
            cek_table,
            parameters,
        })
    }

    /// Reads an `int` identifier/version column that must be non-negative.
    ///
    /// Replaces a plain `as u32` cast, which would silently wrap a negative
    /// value into a large, unrelated id.
    fn describe_u32(row: &Row, idx: usize, col: &str) -> Result<u32, Error> {
        let value = describe_int(row, idx, col)?;
        u32::try_from(value).map_err(|_| {
            Error::Protocol(format!(
                "sp_describe_parameter_encryption column {col} (#{idx}): expected a non-negative int, got {value}"
            ))
        })
    }
}
/// Reads column `idx` of a `sp_describe_parameter_encryption` row as an `int`.
#[cfg(feature = "always-encrypted")]
fn describe_int(row: &Row, idx: usize, col: &str) -> Result<i32, Error> {
    let raw = row.get_raw(idx);
    if let Some(SqlValue::Int(n)) = raw {
        return Ok(n);
    }
    Err(describe_type_error(col, idx, "int", raw.as_ref()))
}
/// Reads column `idx` of a `sp_describe_parameter_encryption` row as a `tinyint`.
#[cfg(feature = "always-encrypted")]
fn describe_tinyint(row: &Row, idx: usize, col: &str) -> Result<u8, Error> {
    let raw = row.get_raw(idx);
    if let Some(SqlValue::TinyInt(n)) = raw {
        return Ok(n);
    }
    Err(describe_type_error(col, idx, "tinyint", raw.as_ref()))
}
/// Reads column `idx` of a `sp_describe_parameter_encryption` row as an
/// owned `nvarchar` string.
#[cfg(feature = "always-encrypted")]
fn describe_nvarchar(row: &Row, idx: usize, col: &str) -> Result<String, Error> {
    let raw = row.get_raw(idx);
    if let Some(SqlValue::String(text)) = raw {
        return Ok(text);
    }
    Err(describe_type_error(col, idx, "nvarchar", raw.as_ref()))
}
/// Reads column `idx` of a `sp_describe_parameter_encryption` row as
/// `varbinary` bytes.
#[cfg(feature = "always-encrypted")]
fn describe_varbinary(row: &Row, idx: usize, col: &str) -> Result<bytes::Bytes, Error> {
    let raw = row.get_raw(idx);
    if let Some(SqlValue::Binary(blob)) = raw {
        return Ok(blob);
    }
    Err(describe_type_error(col, idx, "varbinary", raw.as_ref()))
}
/// Reads the CEK metadata version at column `idx`: exactly 8 bytes,
/// interpreted as a little-endian `u64`. Any other length or type is an error.
#[cfg(feature = "always-encrypted")]
fn describe_md_version(row: &Row, idx: usize) -> Result<u64, Error> {
    let raw = row.get_raw(idx);
    if let Some(SqlValue::Binary(blob)) = &raw {
        // `try_from` succeeds only for a slice of exactly 8 bytes.
        if let Ok(le) = <[u8; 8]>::try_from(&blob[..]) {
            return Ok(u64::from_le_bytes(le));
        }
    }
    Err(describe_type_error(
        "column_encryption_key_metadata_version",
        idx,
        "binary(8)",
        raw.as_ref(),
    ))
}
/// Builds the protocol error for a describe column whose value is missing
/// (`got == None`) or has an unexpected type.
#[cfg(feature = "always-encrypted")]
fn describe_type_error(col: &str, idx: usize, expected: &str, got: Option<&SqlValue>) -> Error {
    let actual = match got {
        Some(value) => value.type_name(),
        None => "missing",
    };
    let message = format!(
        "sp_describe_parameter_encryption column {col} (#{idx}): expected {expected}, got {actual}"
    );
    Error::Protocol(message)
}
/// Converts a parameter value into the plaintext byte layout that is
/// encrypted for an Always Encrypted parameter.
///
/// `param_type` is the declared parameter type, when known. It selects
/// layouts that cannot be inferred from the value alone:
/// - the `char` code page;
/// - temporal scale;
/// - legacy `datetime`.
///
/// Otherwise the value's variant decides:
/// - `Bool`/`TinyInt`/`SmallInt`/`Int`/`BigInt`: widened to an 8-byte
///   little-endian `i64`.
/// - `Float`/`Double`: little-endian IEEE-754 bytes of the stored value.
/// - `String`: UTF-16LE (Windows-1252 when the declared type is `Char`).
/// - `Binary`: the bytes unchanged.
/// - `Uuid`: 16 bytes with the first three groups byte-reversed.
/// - `Date`: 3-byte little-endian day count since 0001-01-01.
/// - `Decimal`: a sign byte (1 = non-negative), then the 16-byte
///   little-endian magnitude of the mantissa.
/// - `Money`/`SmallMoney`: 8 bytes, the high 32 bits then the low 32 bits
///   of the value in 1/10000 units.
/// - `SmallDateTime`: the regular SMALLDATETIME encoding.
///
/// # Errors
///
/// - `EncryptionError::EncryptionFailed` for `char` text that Windows-1252
///   cannot represent.
/// - `EncryptionError::UnsupportedOperation` in these cases:
///   - dates outside 0001-01-01..=9999-12-31;
///   - out-of-range MONEY or SMALLDATETIME values;
///   - a temporal scale above 7;
///   - any value without a normalization (e.g. `Null`).
#[cfg(feature = "always-encrypted")]
pub fn normalize_for_encryption(
    value: &SqlValue,
    param_type: Option<mssql_types::EncryptedParamType>,
) -> Result<Vec<u8>, EncryptionError> {
    // A declared `char` parameter uses a single-byte code page, so its
    // plaintext is Windows-1252 rather than UTF-16.
    if let (Some(mssql_types::EncryptedParamType::Char { .. }), SqlValue::String(s)) =
        (param_type, value)
    {
        let (encoded, _, had_errors) = encoding_rs::WINDOWS_1252.encode(s);
        if had_errors {
            return Err(EncryptionError::EncryptionFailed(
                "char value contains characters not representable in Windows-1252 \
                 (the char column code page); use nchar for non-Latin text"
                    .to_string(),
            ));
        }
        return Ok(encoded.into_owned());
    }
    // Temporal layouts that depend on the declared type (scale or legacy format).
    #[cfg(feature = "chrono")]
    {
        use mssql_types::EncryptedParamType as E;
        match (param_type, value) {
            (Some(E::Time { scale }), SqlValue::Time(t)) => return normalize_ae_time(*t, scale),
            (Some(E::DateTime2 { scale }), SqlValue::DateTime(dt)) => {
                return normalize_ae_datetime2(*dt, scale);
            }
            (Some(E::DateTimeOffset { scale }), SqlValue::DateTimeOffset(dto)) => {
                return normalize_ae_datetimeoffset(*dto, scale);
            }
            (Some(E::DateTime), SqlValue::DateTime(dt)) => {
                let mut buf = bytes::BytesMut::with_capacity(8);
                mssql_types::__private::encode_datetime_legacy(*dt, &mut buf);
                return Ok(buf.to_vec());
            }
            _ => {}
        }
    }
    match value {
        // All integer-like types are widened to 8 bytes.
        SqlValue::Bool(v) => Ok(i64::from(*v).to_le_bytes().to_vec()),
        SqlValue::TinyInt(v) => Ok(i64::from(*v).to_le_bytes().to_vec()),
        SqlValue::SmallInt(v) => Ok(i64::from(*v).to_le_bytes().to_vec()),
        SqlValue::Int(v) => Ok(i64::from(*v).to_le_bytes().to_vec()),
        SqlValue::BigInt(v) => Ok(v.to_le_bytes().to_vec()),
        SqlValue::Float(v) => Ok(v.to_le_bytes().to_vec()),
        SqlValue::Double(v) => Ok(v.to_le_bytes().to_vec()),
        SqlValue::String(s) => Ok(s.encode_utf16().flat_map(u16::to_le_bytes).collect()),
        SqlValue::Binary(b) => Ok(b.to_vec()),
        #[cfg(feature = "uuid")]
        SqlValue::Uuid(u) => {
            let b = u.as_bytes();
            // Mixed-endian GUID layout: the first three groups are byte-reversed.
            Ok(vec![
                b[3], b[2], b[1], b[0], b[5], b[4], b[7], b[6], b[8], b[9], b[10], b[11], b[12],
                b[13], b[14], b[15],
            ])
        }
        #[cfg(feature = "chrono")]
        SqlValue::Date(d) => {
            use chrono::Datelike;
            // SQL Server `date` spans 0001-01-01..=9999-12-31. Outside that range
            // the day count below would wrap (negative) or lose its high byte,
            // silently encrypting a different date.
            if !(1..=9999).contains(&d.year()) {
                return Err(EncryptionError::UnsupportedOperation(format!(
                    "DATE value {d} out of range (0001-01-01 to 9999-12-31)"
                )));
            }
            // Non-negative after the range check, so the cast cannot wrap.
            let days = (d.num_days_from_ce() - 1) as u32;
            Ok(days.to_le_bytes()[..3].to_vec())
        }
        #[cfg(feature = "decimal")]
        SqlValue::Decimal(d) => {
            // NOTE(review): the value's own scale is used as-is; it is not
            // adjusted to the column's declared scale here — confirm callers
            // guarantee matching scales.
            let mut out = Vec::with_capacity(17);
            out.push(u8::from(!d.is_sign_negative()));
            out.extend_from_slice(&d.mantissa().unsigned_abs().to_le_bytes());
            Ok(out)
        }
        #[cfg(feature = "decimal")]
        SqlValue::Money(d) | SqlValue::SmallMoney(d) => {
            let cents = money_cents(d)?;
            // High 32 bits (signed) first, then the low 32 bits.
            let mut out = ((cents >> 32) as i32).to_le_bytes().to_vec();
            out.extend_from_slice(&(cents as u32).to_le_bytes());
            Ok(out)
        }
        #[cfg(feature = "chrono")]
        SqlValue::SmallDateTime(dt) => {
            let mut buf = bytes::BytesMut::with_capacity(4);
            mssql_types::__private::encode_smalldatetime(*dt, &mut buf).map_err(|e| {
                EncryptionError::UnsupportedOperation(format!("SMALLDATETIME: {e}"))
            })?;
            Ok(buf.to_vec())
        }
        other => Err(EncryptionError::UnsupportedOperation(format!(
            "Always Encrypted parameter encryption is not yet implemented for {}",
            other.type_name()
        ))),
    }
}
/// Encodes `d` as a 3-byte little-endian day count since 0001-01-01.
///
/// The result is only meaningful for dates in 0001-01-01..=9999-12-31.
/// Earlier dates wrap in the `as u32` cast, and the fourth byte of the count
/// is discarded.
#[cfg(all(feature = "always-encrypted", feature = "chrono"))]
fn ae_date_bytes(d: chrono::NaiveDate) -> [u8; 3] {
    use chrono::Datelike;
    let day_count = (d.num_days_from_ce() - 1) as u32;
    let [lo, mid, hi, _] = day_count.to_le_bytes();
    [lo, mid, hi]
}
/// Normalizes a `time(scale)` value to 5 little-endian bytes.
///
/// The value is counted in 100 ns ticks since midnight. Precision below the
/// declared `scale` is truncated (not rounded), but the unit stays 100 ns.
///
/// # Errors
///
/// `EncryptionError::UnsupportedOperation` if `scale > 7`.
#[cfg(all(feature = "always-encrypted", feature = "chrono"))]
fn normalize_ae_time(t: chrono::NaiveTime, scale: u8) -> Result<Vec<u8>, EncryptionError> {
    use chrono::Timelike;
    if scale > 7 {
        return Err(EncryptionError::UnsupportedOperation(format!(
            "time scale {scale} out of range (0–7)"
        )));
    }
    let since_midnight_ns =
        u64::from(t.num_seconds_from_midnight()) * 1_000_000_000 + u64::from(t.nanosecond());
    let ticks = since_midnight_ns / 100;
    // One unit of the declared scale, expressed in 100 ns ticks.
    let step = 10u64.pow(7 - u32::from(scale));
    let truncated = ticks - ticks % step;
    let le = truncated.to_le_bytes();
    Ok(le[..5].to_vec())
}
/// Normalizes a `datetime2(scale)` value: the 5-byte time from
/// `normalize_ae_time`, followed by the 3-byte date.
///
/// # Errors
///
/// `EncryptionError::UnsupportedOperation` in either case:
/// - `scale > 7`;
/// - the date falls outside 0001-01-01..=9999-12-31, which the 3-byte day
///   count does not represent faithfully.
#[cfg(all(feature = "always-encrypted", feature = "chrono"))]
fn normalize_ae_datetime2(
    dt: chrono::NaiveDateTime,
    scale: u8,
) -> Result<Vec<u8>, EncryptionError> {
    use chrono::Datelike;
    let mut out = normalize_ae_time(dt.time(), scale)?;
    let date = dt.date();
    // Without this check a pre-0001 date wraps to a huge day count (and a far
    // future one loses its high byte), silently encrypting a different value.
    if !(1..=9999).contains(&date.year()) {
        return Err(EncryptionError::UnsupportedOperation(format!(
            "DATETIME2 value {dt} out of range (0001-01-01 to 9999-12-31)"
        )));
    }
    out.extend_from_slice(&ae_date_bytes(date));
    Ok(out)
}
/// Normalizes a `datetimeoffset(scale)` value in three parts:
/// 1. the UTC time of day (5 bytes, from `normalize_ae_time`);
/// 2. the UTC date (3 bytes);
/// 3. the UTC offset in minutes as a little-endian `i16`.
///
/// # Errors
///
/// `EncryptionError::UnsupportedOperation` in either case:
/// - `scale > 7`;
/// - the UTC date falls outside 0001-01-01..=9999-12-31, which the 3-byte
///   day count does not represent faithfully.
#[cfg(all(feature = "always-encrypted", feature = "chrono"))]
fn normalize_ae_datetimeoffset(
    dto: chrono::DateTime<chrono::FixedOffset>,
    scale: u8,
) -> Result<Vec<u8>, EncryptionError> {
    use chrono::{Datelike, Offset};
    let utc = dto.naive_utc();
    let mut out = normalize_ae_time(utc.time(), scale)?;
    let utc_date = utc.date();
    // Without this check a pre-0001 UTC date wraps to a huge day count (and a
    // far future one loses its high byte), silently encrypting a different value.
    if !(1..=9999).contains(&utc_date.year()) {
        return Err(EncryptionError::UnsupportedOperation(format!(
            "DATETIMEOFFSET value {dto} out of range (UTC date must be between 0001-01-01 and 9999-12-31)"
        )));
    }
    out.extend_from_slice(&ae_date_bytes(utc_date));
    // chrono's FixedOffset stays below ±24h (±1439 minutes), so the cast to
    // i16 cannot overflow.
    let offset_minutes = (dto.offset().fix().local_minus_utc() / 60) as i16;
    out.extend_from_slice(&offset_minutes.to_le_bytes());
    Ok(out)
}
/// Converts `value` to an integer count of 1/10000 units (MONEY's fixed scale
/// of 4), despite the name.
///
/// Digits beyond four decimal places are dropped by integer division
/// (truncation toward zero).
///
/// # Errors
///
/// `EncryptionError::UnsupportedOperation` if the count overflows `i128`
/// while upscaling, or does not fit in `i64`.
#[cfg(all(feature = "always-encrypted", feature = "decimal"))]
fn money_cents(value: &rust_decimal::Decimal) -> Result<i64, EncryptionError> {
    const MONEY_SCALE: u32 = 4;
    let out_of_range =
        || EncryptionError::UnsupportedOperation("MONEY value out of range".into());
    let mantissa = value.mantissa();
    let scale = value.scale();
    let units: i128 = if scale > MONEY_SCALE {
        mantissa / 10_i128.pow(scale - MONEY_SCALE)
    } else {
        mantissa
            .checked_mul(10_i128.pow(MONEY_SCALE - scale))
            .ok_or_else(out_of_range)?
    };
    i64::try_from(units).map_err(|_| out_of_range())
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Decodes an even-length hex string (either letter case) into bytes.
    #[cfg(feature = "always-encrypted")]
    fn unhex(s: &str) -> Vec<u8> {
        (0..s.len())
            .step_by(2)
            .map(|i| u8::from_str_radix(&s[i..i + 2], 16).unwrap())
            .collect()
    }

    /// Normalizes each value (no declared type), encrypts it deterministically
    /// under the hex-encoded CEK, and asserts the ciphertext equals the
    /// reference produced by Microsoft.Data.SqlClient.
    #[cfg(feature = "always-encrypted")]
    fn assert_matches_dotnet(cek_hex: &str, cases: &[(SqlValue, &str)]) {
        let cek = unhex(cek_hex);
        let enc = AeadEncryptor::new(&cek).unwrap();
        for (value, reference) in cases {
            let norm = normalize_for_encryption(value, None).unwrap();
            let cipher = enc
                .encrypt(&norm, mssql_auth::EncryptionType::Deterministic)
                .unwrap();
            assert_eq!(
                cipher,
                unhex(reference),
                "ciphertext for {} must match Microsoft.Data.SqlClient",
                value.type_name()
            );
        }
    }

    #[cfg(feature = "always-encrypted")]
    #[test]
    fn ae_normalization_matches_dotnet() {
        use bytes::Bytes;
        assert_matches_dotnet(
            "B59D9F2C96784C232D53AB273D257DC79B7D2355BB82B1EC7054CE25E25F7B44",
            &[
                (
                    SqlValue::Int(42),
                    "01102FC5DEC5D3E463A8F4BDF512AA74E6AB953BA9A2F3F9A98CD18446B007DE5A6E2A1D1EB775035EA189CA5160A935CE093CAA9BB7E9233BB333AADEE86FDE1D",
                ),
                (
                    SqlValue::String("Ada".to_string()),
                    "01BFAC40E6DA541ACEFAD8ECF5598DB77B0C5349CFACBC3C9221C01B6037E593B78E8F398F620F837BD6A4A2B644125C4188DF278B94479B2218466D91107FE417",
                ),
                (
                    SqlValue::Binary(Bytes::from_static(&[0x01, 0x02, 0x03])),
                    "01ADE71457495F00FC9A16456F1B1EECB901D88DE97887025C189B1C4432E02071AB7594C48518CA5621E90165FAE337475B4CF3A3D00EF2D862FB0473713DF1E1",
                ),
            ],
        );
    }

    #[cfg(feature = "always-encrypted")]
    #[test]
    fn ae_normalization_rejects_unnormalizable_value() {
        assert!(normalize_for_encryption(&SqlValue::Null, None).is_err());
    }

    #[cfg(feature = "always-encrypted")]
    #[test]
    fn ae_normalization_matches_dotnet_numeric() {
        assert_matches_dotnet(
            "9590E42A8A6C8F13B5D09B8D5A128EF8B3A4A10301C7AF24AFC62ED0E02342F7",
            &[
                (
                    SqlValue::BigInt(0x0102030405060708),
                    "01E765FC4696660028BFD48FCAEAED81E0EB423CFF433CA97F1B2FF02F70744E7265C2AE73CAA562FFA98AF98CB1D3EF6A4649B3640359E1DB7D170C80E639DA68",
                ),
                (
                    SqlValue::SmallInt(258),
                    "012545AB817E1AEBDCEE1C00AEBFF3A013CAD20E0377BEFDD9186C263F8D1A909C313A753996F1B5E4A4AE17E901F6F781DCA707544766995D339601CA414063A0",
                ),
                (
                    SqlValue::TinyInt(200),
                    "01A97C33480277D16FFAEDA9068173D4173378542F2887EBCD31CDEEEB116BD59D48F9D459BDDCABAE469E891B4F82AA3D283440CA1B5E9FFC150F9D0AE54EC21E",
                ),
                (
                    SqlValue::Bool(true),
                    "01DDE18564051D630EE026331BCCAFC8F4122CC3919F81459F37D9C0E0C64A5317FCA08660FE5FC855917B97B72013F25B85ADD14ADDD7D5ED022EB1297FF29A7E",
                ),
                (
                    SqlValue::Float(3.5),
                    "017A452760E7BA7AA6A716F6707F55D9C3A81683C04A6B561B13AC1D8A848E93E239BB922EE3EE628B6D0081A590BB11747CC25D216240FB10171A0FA3B99A2DB3",
                ),
                (
                    SqlValue::Double(3.5),
                    "0171611557351FBC4561EBF0B9C98E0DC38AD2BD3E2C1D1E82F185D7E67D0425E506D11DD67BA3EB38F34FB01A8FCEF7E4B9A7256944334A521526613CFF6C8C5F",
                ),
            ],
        );
    }

    #[cfg(all(feature = "always-encrypted", feature = "uuid", feature = "chrono"))]
    #[test]
    fn ae_normalization_matches_dotnet_uuid_date() {
        assert_matches_dotnet(
            "9590E42A8A6C8F13B5D09B8D5A128EF8B3A4A10301C7AF24AFC62ED0E02342F7",
            &[
                (
                    SqlValue::Uuid(
                        uuid::Uuid::parse_str("01020304-0506-0708-090a-0b0c0d0e0f10").unwrap(),
                    ),
                    "01F58635AA18692D68BDF551ECDD7AC3A56682D3F91F111F8D8F36D5425C405A8F6AB3ED3C3666444478476BD65FF40DC83F6831F502826AFEEC3116F71A7A2020CCD254F4BA28FCDC0F96BA2E5264AE9E",
                ),
                (
                    SqlValue::Date(chrono::NaiveDate::from_ymd_opt(2024, 3, 15).unwrap()),
                    "0188B4F75A1F4BDA53C9CDDC1918C09CB57F68E13F5560F1F1D7168FE70707337B1156A97915B244F3C03D3E7352882A599511BD243471FD03683F371CF44E4B76",
                ),
            ],
        );
    }

    #[cfg(all(feature = "always-encrypted", feature = "decimal"))]
    #[test]
    fn ae_normalization_matches_dotnet_decimal_money() {
        // 12345.6789
        let dec = rust_decimal::Decimal::new(123_456_789, 4);
        // 12.3400
        let money = rust_decimal::Decimal::new(123_400, 4);
        assert_matches_dotnet(
            "CBFB5AE21FB517C65DA0C6E8E11969C630798E473EF5827A70398012DF1D4B9E",
            &[
                (
                    SqlValue::Decimal(dec),
                    "018FAE46024B9B406C23600E6A9C694F9A9B39B785A995689EBE19437BA7E75768011A035A5B54B5E495512EBB46AE1146130940A0D0D834D61AA89B5AD9F71FFAF6EEEAE77E4856BA2AA5E016E2950A8D",
                ),
                (
                    SqlValue::Money(money),
                    "01B4CE4CAD8D6B241A1555C377A0ADD4C79424DD5162F710D116594F725C1BAB015169A0C7716076EEC90E013519B961DEF427BFC32462D9E45D166C791B73F793",
                ),
                (
                    SqlValue::SmallMoney(money),
                    "01B4CE4CAD8D6B241A1555C377A0ADD4C79424DD5162F710D116594F725C1BAB015169A0C7716076EEC90E013519B961DEF427BFC32462D9E45D166C791B73F793",
                ),
            ],
        );
    }

    #[cfg(all(feature = "always-encrypted", feature = "chrono"))]
    #[test]
    fn ae_normalization_matches_dotnet_temporal() {
        use mssql_types::EncryptedParamType as E;
        let day = chrono::NaiveDate::from_ymd_opt(2024, 3, 15).unwrap();
        let dt = day.and_hms_nano_opt(13, 14, 15, 123_456_700).unwrap();
        assert_eq!(
            normalize_for_encryption(&SqlValue::Time(dt.time()), Some(E::Time { scale: 7 }))
                .unwrap(),
            unhex("07c4aaf46e"),
        );
        assert_eq!(
            normalize_for_encryption(&SqlValue::DateTime(dt), Some(E::DateTime2 { scale: 7 }))
                .unwrap(),
            unhex("07c4aaf46e8f460b"),
        );
        let dto = {
            use chrono::TimeZone;
            chrono::FixedOffset::east_opt(5 * 3600 + 30 * 60)
                .unwrap()
                .from_local_datetime(&dt)
                .single()
                .unwrap()
        };
        assert_eq!(
            normalize_for_encryption(
                &SqlValue::DateTimeOffset(dto),
                Some(E::DateTimeOffset { scale: 7 })
            )
            .unwrap(),
            unhex("0788f2da408f460b4a01"),
        );
        assert_eq!(
            normalize_for_encryption(&SqlValue::Time(dt.time()), Some(E::Time { scale: 3 }))
                .unwrap(),
            unhex("30b2aaf46e"),
        );
        assert_eq!(
            normalize_for_encryption(&SqlValue::DateTime(dt), Some(E::DateTime2 { scale: 3 }))
                .unwrap(),
            unhex("30b2aaf46e8f460b"),
        );
        assert_eq!(
            normalize_for_encryption(
                &SqlValue::DateTimeOffset(dto),
                Some(E::DateTimeOffset { scale: 3 })
            )
            .unwrap(),
            unhex("3076f2da408f460b4a01"),
        );
        let dt_legacy = day.and_hms_milli_opt(13, 14, 15, 123).unwrap();
        assert_eq!(
            normalize_for_encryption(&SqlValue::DateTime(dt_legacy), Some(E::DateTime)).unwrap(),
            unhex("34b10000d925da00"),
        );
        let sdt = day.and_hms_opt(13, 14, 0).unwrap();
        assert_eq!(
            normalize_for_encryption(&SqlValue::SmallDateTime(sdt), None).unwrap(),
            unhex("34b11a03"),
        );
    }

    #[cfg(feature = "always-encrypted")]
    #[test]
    fn ae_normalization_matches_dotnet_fixed_width() {
        use mssql_types::EncryptedParamType as E;
        assert_eq!(
            normalize_for_encryption(
                &SqlValue::String("Hello".to_string()),
                Some(E::Char { length: 10 })
            )
            .unwrap(),
            unhex("48656c6c6f"),
        );
        assert_eq!(
            normalize_for_encryption(
                &SqlValue::String("Hello".to_string()),
                Some(E::NChar { length: 10 })
            )
            .unwrap(),
            unhex("480065006c006c006f00"),
        );
        assert_eq!(
            normalize_for_encryption(
                &SqlValue::Binary(bytes::Bytes::from_static(&[1, 2, 3, 4, 5])),
                Some(E::Binary { length: 10 })
            )
            .unwrap(),
            unhex("0102030405"),
        );
    }

    #[cfg(feature = "always-encrypted")]
    #[test]
    fn ae_char_rejects_non_windows_1252() {
        use mssql_types::EncryptedParamType as E;
        // U+4E2D (CJK "middle") has no Windows-1252 mapping. It is written as
        // an escape so the test cannot be broken by source re-encoding: the
        // previous literal had been mangled into Latin-1 mojibake, whose
        // characters ARE Windows-1252-representable.
        let r = normalize_for_encryption(
            &SqlValue::String("\u{4e2d}".to_string()),
            Some(E::Char { length: 10 }),
        );
        assert!(
            r.is_err(),
            "non-Windows-1252 char must error, got {:?}",
            r.map(|b| b.iter().map(|x| format!("{x:02x}")).collect::<String>())
        );
        assert_eq!(
            normalize_for_encryption(
                &SqlValue::String("\u{e9}".to_string()),
                Some(E::Char { length: 10 })
            )
            .unwrap(),
            vec![0xE9],
        );
    }

    #[test]
    fn test_encryption_config_defaults() {
        let config = EncryptionConfig::new();
        assert!(config.enabled);
        assert!(config.cache_ceks);
        // No providers registered yet.
        assert!(!config.is_ready());

        // The derived Default is a disabled configuration, unlike `new()`.
        let disabled = EncryptionConfig::default();
        assert!(!disabled.enabled);
        assert!(!disabled.cache_ceks);
        assert!(!disabled.is_ready());
    }

    #[cfg(feature = "always-encrypted")]
    #[test]
    fn parse_describe_result_sets_groups_ceks_and_skips_plaintext() {
        use crate::row::{Column, Row};
        use crate::stream::ResultSet;
        use bytes::Bytes;
        fn rs(n_cols: usize, rows: Vec<Vec<SqlValue>>) -> ResultSet {
            let cols: Vec<Column> = (0..n_cols)
                .map(|i| Column::new(format!("c{i}"), i, "x"))
                .collect();
            let rows = rows
                .into_iter()
                .map(|vals| Row::from_values(cols.clone(), vals))
                .collect();
            ResultSet::new(cols, rows)
        }
        // Metadata versions as 8-byte little-endian values: 1 and 255.
        let mdv1 = Bytes::from_static(&[1, 0, 0, 0, 0, 0, 0, 0]);
        let mdv2 = Bytes::from_static(&[255, 0, 0, 0, 0, 0, 0, 0]);
        let rs1 = rs(
            9,
            vec![
                vec![
                    SqlValue::Int(1),
                    SqlValue::Int(7),
                    SqlValue::Int(56),
                    SqlValue::Int(1),
                    SqlValue::Binary(mdv1.clone()),
                    SqlValue::Binary(Bytes::from_static(b"env-a")),
                    SqlValue::String("IN_MEMORY_KEY_STORE".into()),
                    SqlValue::String("path-a".into()),
                    SqlValue::String("RSA_OAEP".into()),
                ],
                vec![
                    SqlValue::Int(1),
                    SqlValue::Int(7),
                    SqlValue::Int(56),
                    SqlValue::Int(1),
                    SqlValue::Binary(mdv1),
                    SqlValue::Binary(Bytes::from_static(b"env-a2")),
                    SqlValue::String("PROV_2".into()),
                    SqlValue::String("path-a2".into()),
                    SqlValue::String("RSA_OAEP".into()),
                ],
                vec![
                    SqlValue::Int(2),
                    SqlValue::Int(7),
                    SqlValue::Int(57),
                    SqlValue::Int(1),
                    SqlValue::Binary(mdv2),
                    SqlValue::Binary(Bytes::from_static(b"env-b")),
                    SqlValue::String("IN_MEMORY_KEY_STORE".into()),
                    SqlValue::String("path-b".into()),
                    SqlValue::String("RSA_OAEP".into()),
                ],
            ],
        );
        let rs2 = rs(
            6,
            vec![
                vec![
                    SqlValue::Int(1),
                    SqlValue::String("@det".into()),
                    SqlValue::TinyInt(2),
                    SqlValue::TinyInt(1),
                    SqlValue::Int(1),
                    SqlValue::TinyInt(1),
                ],
                vec![
                    SqlValue::Int(2),
                    SqlValue::String("@rand".into()),
                    SqlValue::TinyInt(2),
                    SqlValue::TinyInt(2),
                    SqlValue::Int(2),
                    SqlValue::TinyInt(1),
                ],
                vec![
                    SqlValue::Int(3),
                    SqlValue::String("@plain".into()),
                    SqlValue::TinyInt(0),
                    SqlValue::TinyInt(0),
                    SqlValue::Int(0),
                    SqlValue::TinyInt(0),
                ],
            ],
        );
        let mut sets = vec![rs1, rs2];
        let info = ParameterEncryptionInfo::from_describe_result_sets(&mut sets).unwrap();
        assert_eq!(info.cek_table.len(), 2);
        let e0 = info.cek_table.get(0).unwrap();
        assert_eq!(e0.cek_id, 56);
        assert_eq!(e0.cek_md_version, 1);
        assert_eq!(e0.values.len(), 2, "two CMK-wrappings group under one CEK");
        assert_eq!(e0.values[0].key_store_provider_name, "IN_MEMORY_KEY_STORE");
        assert_eq!(e0.values[1].key_store_provider_name, "PROV_2");
        let e1 = info.cek_table.get(1).unwrap();
        assert_eq!(e1.cek_id, 57);
        assert_eq!(e1.cek_md_version, 255);
        let det = info.get_parameter("@det").unwrap();
        assert_eq!(det.encryption_type, EncryptionTypeWire::Deterministic);
        assert_eq!(det.algorithm_id, 2);
        assert_eq!(det.normalization_rule_version, 1);
        assert_eq!(det.cek_ordinal, 0, "server ordinal 1 -> positional index 0");
        let rand = info.get_parameter("@rand").unwrap();
        assert_eq!(rand.encryption_type, EncryptionTypeWire::Randomized);
        assert_eq!(
            rand.cek_ordinal, 1,
            "server ordinal 2 -> positional index 1"
        );
        assert!(info.get_parameter("@plain").is_none());
        assert_eq!(info.parameters.len(), 2);
    }

    #[cfg(feature = "always-encrypted")]
    #[test]
    fn parse_describe_result_sets_rejects_missing_result_set() {
        use crate::row::{Column, Row};
        use crate::stream::ResultSet;
        let cols: Vec<Column> = (0..9)
            .map(|i| Column::new(format!("c{i}"), i, "x"))
            .collect();
        let mut sets = vec![ResultSet::new(cols, Vec::<Row>::new())];
        assert!(ParameterEncryptionInfo::from_describe_result_sets(&mut sets).is_err());
    }
}