use std::collections::HashMap;
use mssql_auth::KeyStoreProvider;
use tds_protocol::crypto::{CekTable, CekTableEntry, CryptoMetadata, EncryptionTypeWire};
#[cfg(feature = "always-encrypted")]
use mssql_auth::{AeadEncryptor, CekCache, CekCacheKey, EncryptionError};
#[cfg(feature = "always-encrypted")]
use mssql_types::SqlValue;
#[cfg(feature = "always-encrypted")]
use std::sync::Arc;
#[cfg(feature = "always-encrypted")]
use crate::{Error, row::Row, stream::ResultSet};
#[cfg(feature = "always-encrypted")]
use tds_protocol::crypto::CekValue;
/// Client-side settings for SQL Server Always Encrypted.
///
/// Holds the registered key store providers plus two flags: a master switch
/// (`enabled`) and whether decrypted column encryption keys should be cached
/// (`cache_ceks`).
///
/// NOTE: the derived `Default` sets both flags to `false` (the `bool`
/// default), whereas `EncryptionConfig::new()` sets both to `true`.
#[derive(Default)]
pub struct EncryptionConfig {
/// Master switch; `is_ready()` requires it to be `true`.
pub enabled: bool,
/// Registered key store providers; lookups match on `provider_name()`.
providers: Vec<Box<dyn KeyStoreProvider>>,
/// Whether decrypted CEKs should be cached; an `EncryptionContext` built
/// from this config copies the flag when it is constructed.
pub cache_ceks: bool,
}
impl EncryptionConfig {
    /// Creates a configuration with encryption and CEK caching both enabled
    /// and no key store providers registered yet.
    #[must_use]
    pub fn new() -> Self {
        let providers = Vec::new();
        Self {
            enabled: true,
            cache_ceks: true,
            providers,
        }
    }

    /// Appends a key store provider after any previously registered ones.
    ///
    /// Lookups through [`Self::get_provider`] return the first match by name,
    /// so a later provider with a duplicate name is never selected.
    pub fn register_provider(&mut self, provider: impl KeyStoreProvider + 'static) {
        let boxed: Box<dyn KeyStoreProvider> = Box::new(provider);
        self.providers.push(boxed);
    }

    /// Builder-style variant of [`Self::register_provider`].
    #[must_use]
    pub fn with_provider(mut self, provider: impl KeyStoreProvider + 'static) -> Self {
        Self::register_provider(&mut self, provider);
        self
    }

    /// Builder-style setter for the CEK caching flag.
    #[must_use]
    pub fn with_cek_caching(mut self, enabled: bool) -> Self {
        self.cache_ceks = enabled;
        self
    }

    /// Returns the first registered provider whose `provider_name()` equals `name`.
    pub fn get_provider(&self, name: &str) -> Option<&dyn KeyStoreProvider> {
        self.providers
            .iter()
            .map(|boxed| &**boxed)
            .find(|candidate| candidate.provider_name() == name)
    }

    /// `true` when encryption is enabled and at least one provider is registered.
    #[must_use]
    pub fn is_ready(&self) -> bool {
        let has_provider = !self.providers.is_empty();
        has_provider && self.enabled
    }
}
impl std::fmt::Debug for EncryptionConfig {
    // Providers are trait objects without a Debug bound, so only their count
    // is printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            enabled,
            providers,
            cache_ceks,
        } = self;
        let provider_count = providers.len();
        let mut out = f.debug_struct("EncryptionConfig");
        out.field("enabled", enabled);
        out.field("provider_count", &provider_count);
        out.field("cache_ceks", cache_ceks);
        out.finish()
    }
}
/// Runtime state for Always Encrypted value encryption and decryption:
/// the shared configuration plus a cache of encryptors built from
/// decrypted column encryption keys (CEKs).
#[cfg(feature = "always-encrypted")]
pub struct EncryptionContext {
/// Shared configuration (key store providers and caching flag).
config: std::sync::Arc<EncryptionConfig>,
/// Encryptors keyed by `(database_id, cek_id, cek_version)`.
cek_cache: CekCache,
/// Copy of `config.cache_ceks` taken when the context was created.
cache_enabled: bool,
}
#[cfg(feature = "always-encrypted")]
impl EncryptionContext {
    /// Creates a context that shares an already reference-counted configuration.
    ///
    /// Whether CEK caching is used is fixed from `config.cache_ceks` at this point.
    pub fn from_arc(config: std::sync::Arc<EncryptionConfig>) -> Self {
        let cache_enabled = config.cache_ceks;
        Self {
            config,
            cek_cache: CekCache::new(),
            cache_enabled,
        }
    }

    /// Creates a context that takes ownership of `config`.
    pub fn new(config: EncryptionConfig) -> Self {
        Self::from_arc(std::sync::Arc::new(config))
    }

    /// Returns an AEAD encryptor for the column encryption key in `cek_entry`.
    ///
    /// When caching is enabled, a cached encryptor keyed by
    /// `(database_id, cek_id, cek_version)` is returned if present. Otherwise
    /// the encrypted CEK is unwrapped through a registered key store provider
    /// (and inserted into the cache when caching is enabled).
    ///
    /// A CEK may carry several encrypted values (one per column master key
    /// wrapping). The first value whose key store provider is registered is
    /// used, so the key can still be unwrapped when only the provider for a
    /// non-first wrapping is available.
    ///
    /// # Errors
    ///
    /// - `EncryptionError::CekDecryptionFailed` if `cek_entry` has no values.
    /// - `EncryptionError::KeyStoreNotFound`, naming the first value's provider,
    ///   if no value's provider is registered.
    /// - Any error from the provider's `decrypt_cek`, the cache insert, or
    ///   `AeadEncryptor::new`.
    pub async fn get_encryptor(
        &self,
        cek_entry: &CekTableEntry,
    ) -> Result<Arc<AeadEncryptor>, EncryptionError> {
        let cache_key = CekCacheKey::new(
            cek_entry.database_id,
            cek_entry.cek_id,
            cek_entry.cek_version,
        );
        if self.cache_enabled {
            if let Some(encryptor) = self.cek_cache.get(&cache_key) {
                return Ok(encryptor);
            }
        }

        // Choose the first wrapping we can actually unwrap rather than
        // insisting on the first wrapping only.
        let (cek_value, provider) = cek_entry
            .values
            .iter()
            .find_map(|value| {
                self.config
                    .get_provider(&value.key_store_provider_name)
                    .map(|provider| (value, provider))
            })
            .ok_or_else(|| match cek_entry.values.first() {
                // Same error a single-wrapping entry produces: name the first
                // wrapping's provider.
                Some(first) => {
                    EncryptionError::KeyStoreNotFound(first.key_store_provider_name.clone())
                }
                None => EncryptionError::CekDecryptionFailed("No CEK value available".into()),
            })?;

        let decrypted_cek = provider
            .decrypt_cek(
                &cek_value.cmk_path,
                &cek_value.encryption_algorithm,
                &cek_value.encrypted_value,
            )
            .await?;
        if self.cache_enabled {
            self.cek_cache.insert(cache_key, decrypted_cek)
        } else {
            Ok(Arc::new(AeadEncryptor::new(&decrypted_cek)?))
        }
    }

    /// Encrypts `plaintext` with the CEK described by `cek_entry`.
    ///
    /// # Errors
    ///
    /// `EncryptionError::UnsupportedOperation` if `encryption_type` is neither
    /// deterministic nor randomized (checked before any key store access),
    /// plus any error from [`Self::get_encryptor`] or the encryptor itself.
    pub async fn encrypt_value(
        &self,
        plaintext: &[u8],
        cek_entry: &CekTableEntry,
        encryption_type: EncryptionTypeWire,
    ) -> Result<Vec<u8>, EncryptionError> {
        // Validate the requested mode first so an unsupported type is rejected
        // without calling a key store provider to unwrap the CEK.
        let enc_type = match encryption_type {
            EncryptionTypeWire::Deterministic => mssql_auth::EncryptionType::Deterministic,
            EncryptionTypeWire::Randomized => mssql_auth::EncryptionType::Randomized,
            _ => {
                return Err(EncryptionError::UnsupportedOperation(format!(
                    "unsupported encryption type: {encryption_type:?}"
                )));
            }
        };
        let encryptor = self.get_encryptor(cek_entry).await?;
        encryptor.encrypt(plaintext, enc_type)
    }

    /// Decrypts `ciphertext` with the CEK described by `cek_entry`.
    ///
    /// # Errors
    ///
    /// Any error from [`Self::get_encryptor`] or the encryptor's `decrypt`.
    pub async fn decrypt_value(
        &self,
        ciphertext: &[u8],
        cek_entry: &CekTableEntry,
    ) -> Result<Vec<u8>, EncryptionError> {
        let encryptor = self.get_encryptor(cek_entry).await?;
        encryptor.decrypt(ciphertext)
    }

    /// Empties the encryptor cache.
    pub fn clear_cache(&self) {
        self.cek_cache.clear();
    }

    /// `true` if a key store provider named `name` is registered.
    pub fn has_provider(&self, name: &str) -> bool {
        self.config.get_provider(name).is_some()
    }
}
#[cfg(feature = "always-encrypted")]
impl std::fmt::Debug for EncryptionContext {
    // Only counts and the caching flag are printed; cache contents are not.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let provider_count = self.config.providers.len();
        let cache_entries = self.cek_cache.len();
        let cache_enabled = self.cache_enabled;
        f.debug_struct("EncryptionContext")
            .field("provider_count", &provider_count)
            .field("cache_entries", &cache_entries)
            .field("cache_enabled", &cache_enabled)
            .finish()
    }
}
/// Always Encrypted metadata for one result set: its CEK table and, per
/// column ordinal, the crypto metadata (`None` for a non-encrypted column).
#[derive(Debug, Clone)]
pub struct ResultSetEncryptionInfo {
/// Column encryption keys referenced by the encrypted columns.
pub cek_table: CekTable,
/// Indexed by column ordinal; `Some` only for encrypted columns.
pub column_crypto: Vec<Option<CryptoMetadata>>,
}
impl ResultSetEncryptionInfo {
    /// Creates info for `column_count` columns, all initially unencrypted.
    pub fn new(cek_table: CekTable, column_count: usize) -> Self {
        let column_crypto = (0..column_count).map(|_| None).collect();
        Self {
            cek_table,
            column_crypto,
        }
    }

    /// Records crypto metadata for the column at `ordinal`.
    ///
    /// Ordinals past the column count are silently ignored.
    pub fn set_column_crypto(&mut self, ordinal: usize, metadata: CryptoMetadata) {
        if let Some(slot) = self.column_crypto.get_mut(ordinal) {
            *slot = Some(metadata);
        }
    }

    /// Returns the CEK table entry used by the column at `ordinal`.
    ///
    /// `None` if the ordinal is out of range, the column is not encrypted,
    /// or `CekTable::get` finds no entry for the column's CEK ordinal.
    pub fn get_cek_for_column(&self, ordinal: usize) -> Option<&CekTableEntry> {
        match self.column_crypto.get(ordinal) {
            Some(Some(crypto)) => self.cek_table.get(crypto.cek_table_ordinal),
            _ => None,
        }
    }

    /// `true` if the column at `ordinal` exists and has crypto metadata.
    pub fn is_column_encrypted(&self, ordinal: usize) -> bool {
        matches!(self.column_crypto.get(ordinal), Some(Some(_)))
    }

    /// Encryption type of the column at `ordinal`, if it is encrypted.
    pub fn get_encryption_type(&self, ordinal: usize) -> Option<EncryptionTypeWire> {
        self.column_crypto
            .get(ordinal)
            .and_then(Option::as_ref)
            .map(|crypto| crypto.encryption_type)
    }
}
/// Encryption requirements for the parameters of a parameterized statement.
#[derive(Debug, Clone)]
pub struct ParameterEncryptionInfo {
/// Column encryption keys referenced by the encrypted parameters; each
/// parameter's `cek_ordinal` is a position in this table.
pub cek_table: CekTable,
/// Crypto settings keyed by parameter name; only parameters that need
/// encryption have an entry.
pub parameters: HashMap<String, ParameterCryptoInfo>,
}
impl ParameterEncryptionInfo {
    /// Creates an empty description: no CEKs and no encrypted parameters.
    pub fn new() -> Self {
        let cek_table = CekTable::new();
        let parameters = HashMap::new();
        Self {
            cek_table,
            parameters,
        }
    }

    /// Records (or replaces) the crypto settings for parameter `name`.
    pub fn add_parameter(&mut self, name: String, info: ParameterCryptoInfo) {
        self.parameters.insert(name, info);
    }

    /// Crypto settings for parameter `name`, if it must be encrypted.
    pub fn get_parameter(&self, name: &str) -> Option<&ParameterCryptoInfo> {
        self.parameters.get(name)
    }

    /// `true` if parameter `name` has crypto settings recorded.
    pub fn needs_encryption(&self, name: &str) -> bool {
        self.get_parameter(name).is_some()
    }
}
impl Default for ParameterEncryptionInfo {
fn default() -> Self {
Self::new()
}
}
/// Encryption settings for a single statement parameter.
#[derive(Debug, Clone)]
pub struct ParameterCryptoInfo {
/// Position of the parameter's CEK in `ParameterEncryptionInfo::cek_table`
/// (not the server-assigned CEK ordinal).
pub cek_ordinal: u16,
/// Wire encryption type (e.g. deterministic or randomized).
pub encryption_type: EncryptionTypeWire,
/// Encryption algorithm identifier reported by the server.
pub algorithm_id: u8,
/// Normalization rule version reported by the server.
pub normalization_rule_version: u8,
}
impl ParameterCryptoInfo {
    /// Bundles the per-parameter encryption settings; no validation is done.
    pub fn new(
        cek_ordinal: u16,
        encryption_type: EncryptionTypeWire,
        algorithm_id: u8,
        normalization_rule_version: u8,
    ) -> Self {
        Self {
            encryption_type,
            cek_ordinal,
            normalization_rule_version,
            algorithm_id,
        }
    }
}
#[cfg(feature = "always-encrypted")]
impl ParameterEncryptionInfo {
// Minimum column counts: the highest column index read below is 8 in
// result set 1 and 5 in result set 2.
const RS1_MIN_COLS: usize = 9;
const RS2_MIN_COLS: usize = 6;
/// Builds parameter encryption info from the result sets returned by
/// `sp_describe_parameter_encryption`.
///
/// Result set 1 (one row per CEK wrapping) is read by column position:
/// 0 key ordinal, 1 database_id, 2 key id, 3 key version, 4 metadata
/// version (binary(8), decoded little-endian), 5 encrypted CEK, 6 key store
/// provider name, 7 CMK path, 8 encryption algorithm name. Rows sharing a
/// key ordinal are grouped into one `CekTableEntry` holding several values,
/// in row order.
///
/// Result set 2 (one row per parameter) is read as: 1 parameter name,
/// 2 algorithm id, 3 encryption type, 4 key ordinal, 5 normalization rule
/// version. Rows whose encryption type is 0 are skipped. The server's key
/// ordinal is translated to the entry's position in the CEK table built
/// from result set 1.
///
/// # Errors
///
/// `Error::Protocol` if fewer than two result sets are supplied, a result
/// set has too few columns, a column holds an unexpected type, an
/// encryption type is rejected by `EncryptionTypeWire::from_u8`, a
/// parameter references a key ordinal absent from result set 1, or more
/// than 65 536 distinct CEKs are returned. Errors from `collect_all` are
/// propagated.
pub(crate) fn from_describe_result_sets(result_sets: &mut [ResultSet]) -> Result<Self, Error> {
if result_sets.len() < 2 {
return Err(Error::Protocol(format!(
"sp_describe_parameter_encryption returned {} result set(s), expected 2",
result_sets.len()
)));
}
// --- Result set 1: column encryption keys ---
let rs1_cols = result_sets[0].columns().len();
if rs1_cols < Self::RS1_MIN_COLS {
return Err(Error::Protocol(format!(
"sp_describe_parameter_encryption result set 1 has {rs1_cols} columns, expected >= {}",
Self::RS1_MIN_COLS
)));
}
let rs1_rows = result_sets[0].collect_all()?;
let mut entries: Vec<CekTableEntry> = Vec::new();
// Server key ordinal -> position in `entries`.
let mut ordinal_to_index: HashMap<i32, u16> = HashMap::new();
for row in &rs1_rows {
let key_ordinal = describe_int(row, 0, "column_encryption_key_ordinal")?;
let value = CekValue {
encrypted_value: describe_varbinary(
row,
5,
"column_encryption_key_encrypted_value",
)?,
key_store_provider_name: describe_nvarchar(
row,
6,
"column_master_key_store_provider_name",
)?,
cmk_path: describe_nvarchar(row, 7, "column_master_key_path")?,
encryption_algorithm: describe_nvarchar(
row,
8,
"column_encryption_key_encryption_algorithm_name",
)?,
};
if let Some(&idx) = ordinal_to_index.get(&key_ordinal) {
// Already-seen ordinal: this row only adds another wrapping of the
// same CEK. Its id/version columns (1-4) are not read or compared.
entries[idx as usize].values.push(value);
} else {
let idx = u16::try_from(entries.len()).map_err(|_| {
Error::Protocol(
"sp_describe_parameter_encryption returned too many CEKs".into(),
)
})?;
ordinal_to_index.insert(key_ordinal, idx);
// `as u32` reinterprets the i32 bits; negative values are not rejected.
entries.push(CekTableEntry {
database_id: describe_int(row, 1, "database_id")? as u32,
cek_id: describe_int(row, 2, "column_encryption_key_id")? as u32,
cek_version: describe_int(row, 3, "column_encryption_key_version")? as u32,
cek_md_version: describe_md_version(row, 4)?,
values: vec![value],
});
}
}
let cek_table = CekTable { entries };
// --- Result set 2: per-parameter encryption settings ---
let rs2_cols = result_sets[1].columns().len();
if rs2_cols < Self::RS2_MIN_COLS {
return Err(Error::Protocol(format!(
"sp_describe_parameter_encryption result set 2 has {rs2_cols} columns, expected >= {}",
Self::RS2_MIN_COLS
)));
}
let rs2_rows = result_sets[1].collect_all()?;
let mut parameters = HashMap::new();
for row in &rs2_rows {
let name = describe_nvarchar(row, 1, "parameter_name")?;
let encryption_type_byte = describe_tinyint(row, 3, "column_encryption_type")?;
// Type 0: the parameter is not encrypted. Columns 2, 4 and 5 are read
// only after this check, so plaintext rows need not carry typed values
// there; keep this ordering.
if encryption_type_byte == 0 {
continue;
}
let encryption_type =
EncryptionTypeWire::from_u8(encryption_type_byte).ok_or_else(|| {
Error::Protocol(format!(
"sp_describe_parameter_encryption: invalid column_encryption_type {encryption_type_byte} for {name}"
))
})?;
let algorithm_id = describe_tinyint(row, 2, "column_encryption_algorithm")?;
let server_ordinal = describe_int(row, 4, "column_encryption_key_ordinal")?;
let normalization_rule_version =
describe_tinyint(row, 5, "column_encryption_normalization_rule_version")?;
// Translate the server's key ordinal into a position in `cek_table`.
let cek_ordinal = *ordinal_to_index.get(&server_ordinal).ok_or_else(|| {
Error::Protocol(format!(
"sp_describe_parameter_encryption: parameter {name} references CEK ordinal {server_ordinal} absent from the CEK table"
))
})?;
parameters.insert(
name,
ParameterCryptoInfo {
cek_ordinal,
encryption_type,
algorithm_id,
normalization_rule_version,
},
);
}
Ok(Self {
cek_table,
parameters,
})
}
}
/// Reads column `idx` of a `sp_describe_parameter_encryption` row as an `int`.
///
/// Any other value, including NULL or a missing column, becomes a protocol
/// error naming `col`.
#[cfg(feature = "always-encrypted")]
fn describe_int(row: &Row, idx: usize, col: &str) -> Result<i32, Error> {
    let raw = row.get_raw(idx);
    if let Some(SqlValue::Int(number)) = raw {
        return Ok(number);
    }
    Err(describe_type_error(col, idx, "int", raw.as_ref()))
}
/// Reads column `idx` of a `sp_describe_parameter_encryption` row as a `tinyint`.
///
/// Any other value, including NULL or a missing column, becomes a protocol
/// error naming `col`.
#[cfg(feature = "always-encrypted")]
fn describe_tinyint(row: &Row, idx: usize, col: &str) -> Result<u8, Error> {
    let raw = row.get_raw(idx);
    if let Some(SqlValue::TinyInt(byte)) = raw {
        return Ok(byte);
    }
    Err(describe_type_error(col, idx, "tinyint", raw.as_ref()))
}
/// Reads column `idx` of a `sp_describe_parameter_encryption` row as an
/// `nvarchar` string.
///
/// Any other value, including NULL or a missing column, becomes a protocol
/// error naming `col`.
#[cfg(feature = "always-encrypted")]
fn describe_nvarchar(row: &Row, idx: usize, col: &str) -> Result<String, Error> {
    let raw = row.get_raw(idx);
    match raw {
        Some(SqlValue::String(text)) => Ok(text),
        _ => Err(describe_type_error(col, idx, "nvarchar", raw.as_ref())),
    }
}
/// Reads column `idx` of a `sp_describe_parameter_encryption` row as
/// `varbinary` bytes.
///
/// Any other value, including NULL or a missing column, becomes a protocol
/// error naming `col`.
#[cfg(feature = "always-encrypted")]
fn describe_varbinary(row: &Row, idx: usize, col: &str) -> Result<bytes::Bytes, Error> {
    let raw = row.get_raw(idx);
    match raw {
        Some(SqlValue::Binary(blob)) => Ok(blob),
        _ => Err(describe_type_error(col, idx, "varbinary", raw.as_ref())),
    }
}
/// Reads the CEK metadata version at column `idx`: a binary value of exactly
/// 8 bytes, decoded as a little-endian `u64`.
///
/// Any other value (wrong length, wrong type, NULL, or missing) becomes a
/// protocol error.
#[cfg(feature = "always-encrypted")]
fn describe_md_version(row: &Row, idx: usize) -> Result<u64, Error> {
    let raw = row.get_raw(idx);
    if let Some(SqlValue::Binary(blob)) = &raw {
        // The array conversion succeeds only for a slice of exactly 8 bytes.
        if let Ok(le_bytes) = <[u8; 8]>::try_from(&blob[..]) {
            return Ok(u64::from_le_bytes(le_bytes));
        }
    }
    Err(describe_type_error(
        "column_encryption_key_metadata_version",
        idx,
        "binary(8)",
        raw.as_ref(),
    ))
}
/// Builds the protocol error for an unexpected value in a
/// `sp_describe_parameter_encryption` column. A missing value (`None`) is
/// reported as "missing"; otherwise the value's `type_name()` is used.
#[cfg(feature = "always-encrypted")]
fn describe_type_error(col: &str, idx: usize, expected: &str, got: Option<&SqlValue>) -> Error {
    let got = match got {
        Some(value) => value.type_name(),
        None => "missing",
    };
    Error::Protocol(format!(
        "sp_describe_parameter_encryption column {col} (#{idx}): expected {expected}, got {got}"
    ))
}
/// Serializes `value` into the byte layout that is encrypted for an Always
/// Encrypted parameter (the tests in this module compare the resulting
/// ciphertext against Microsoft.Data.SqlClient).
///
/// Layouts:
/// - `Bool`, `TinyInt`, `SmallInt`, `Int`, `BigInt`: widened to an 8-byte
///   little-endian `i64`.
/// - `Float`, `Double`: the little-endian IEEE-754 bytes of the stored float.
/// - `String`: UTF-16LE code units.
/// - `Binary`: the bytes unchanged.
/// - `Uuid` (feature `uuid`): GUID byte order (first three groups byte-swapped).
/// - `Date` (feature `chrono`): 3-byte little-endian day count since 0001-01-01.
/// - `Decimal` (feature `decimal`): one sign byte (1 = non-negative) followed
///   by the 16-byte little-endian magnitude of the mantissa.
/// - `Money` / `SmallMoney` (feature `decimal`): the value in 1/10 000 units
///   as 8 bytes, high 32 bits first, each half little-endian.
///
/// # Errors
///
/// `EncryptionError::UnsupportedOperation` for variants without a layout
/// (e.g. `Null`), dates outside 0001-01-01..=9999-12-31, and MONEY values
/// that do not fit in 64 bits.
#[cfg(feature = "always-encrypted")]
pub fn normalize_for_encryption(value: &SqlValue) -> Result<Vec<u8>, EncryptionError> {
    match value {
        SqlValue::Bool(v) => Ok(i64::from(*v).to_le_bytes().to_vec()),
        SqlValue::TinyInt(v) => Ok(i64::from(*v).to_le_bytes().to_vec()),
        SqlValue::SmallInt(v) => Ok(i64::from(*v).to_le_bytes().to_vec()),
        SqlValue::Int(v) => Ok(i64::from(*v).to_le_bytes().to_vec()),
        SqlValue::BigInt(v) => Ok(v.to_le_bytes().to_vec()),
        SqlValue::Float(v) => Ok(v.to_le_bytes().to_vec()),
        SqlValue::Double(v) => Ok(v.to_le_bytes().to_vec()),
        SqlValue::String(s) => Ok(s.encode_utf16().flat_map(u16::to_le_bytes).collect()),
        SqlValue::Binary(b) => Ok(b.to_vec()),
        #[cfg(feature = "uuid")]
        SqlValue::Uuid(u) => {
            let b = u.as_bytes();
            Ok(vec![
                b[3], b[2], b[1], b[0], b[5], b[4], b[7], b[6], b[8], b[9], b[10], b[11], b[12],
                b[13], b[14], b[15],
            ])
        }
        #[cfg(feature = "chrono")]
        SqlValue::Date(d) => {
            use chrono::Datelike;
            // SQL Server `date` covers 0001-01-01..=9999-12-31, i.e. day
            // indices 0..=3_652_058. Anything outside that range would wrap
            // (negative) or be truncated to 3 bytes, silently encrypting a
            // different date.
            const MAX_DATE_DAYS: u32 = 3_652_058;
            let days = u32::try_from(d.num_days_from_ce() - 1)
                .ok()
                .filter(|days| *days <= MAX_DATE_DAYS)
                .ok_or_else(|| {
                    EncryptionError::UnsupportedOperation(format!(
                        "DATE value {d} is outside the SQL Server range 0001-01-01..=9999-12-31"
                    ))
                })?;
            Ok(days.to_le_bytes()[..3].to_vec())
        }
        #[cfg(feature = "decimal")]
        SqlValue::Decimal(d) => {
            // NOTE(review): the mantissa is emitted at the Decimal's own scale.
            // If the target column's scale differs, the reader would interpret
            // it at the column scale — TODO confirm callers rescale first.
            let mut out = Vec::with_capacity(17);
            out.push(u8::from(!d.is_sign_negative()));
            out.extend_from_slice(&d.mantissa().unsigned_abs().to_le_bytes());
            Ok(out)
        }
        #[cfg(feature = "decimal")]
        SqlValue::Money(d) | SqlValue::SmallMoney(d) => {
            // NOTE(review): SMALLMONEY uses the same 8-byte layout here and its
            // narrower range is not checked.
            let cents = money_cents(d)?;
            let mut out = ((cents >> 32) as i32).to_le_bytes().to_vec();
            out.extend_from_slice(&(cents as u32).to_le_bytes());
            Ok(out)
        }
        other => Err(EncryptionError::UnsupportedOperation(format!(
            "Always Encrypted parameter encryption is not yet implemented for {}",
            other.type_name()
        ))),
    }
}
/// Converts a decimal to MONEY's fixed-point representation: an `i64` count
/// of 1/10 000 units (scale 4).
///
/// Values with more than four fractional digits are rounded half away from
/// zero. This matches `SqlMoney(decimal)` in Microsoft.Data.SqlClient, so
/// deterministic ciphertext agrees with that driver.
///
/// # Errors
///
/// `EncryptionError::UnsupportedOperation` if the scaled value does not fit
/// in an `i64`.
#[cfg(all(feature = "always-encrypted", feature = "decimal"))]
fn money_cents(value: &rust_decimal::Decimal) -> Result<i64, EncryptionError> {
    let mantissa = value.mantissa();
    let scale = value.scale();
    let units: i128 = if scale <= 4 {
        // Scale up to exactly four fractional digits.
        mantissa
            .checked_mul(10_i128.pow(4 - scale))
            .ok_or_else(|| {
                EncryptionError::UnsupportedOperation("MONEY value out of range".into())
            })?
    } else {
        // Drop the extra fractional digits, rounding half away from zero.
        let divisor = 10_i128.pow(scale - 4);
        let quotient = mantissa / divisor;
        let remainder = mantissa % divisor;
        if remainder.abs() * 2 >= divisor {
            // A non-zero remainder implies a non-zero mantissa with the same sign.
            quotient + mantissa.signum()
        } else {
            quotient
        }
    };
    i64::try_from(units)
        .map_err(|_| EncryptionError::UnsupportedOperation("MONEY value out of range".into()))
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Decodes an even-length ASCII hex string into bytes (fixtures only).
    #[cfg(feature = "always-encrypted")]
    fn unhex(s: &str) -> Vec<u8> {
        (0..s.len())
            .step_by(2)
            .map(|i| u8::from_str_radix(&s[i..i + 2], 16).unwrap())
            .collect()
    }

    /// Normalizes each value, encrypts it deterministically with the hex-encoded
    /// CEK, and asserts the ciphertext equals the hex reference produced by
    /// Microsoft.Data.SqlClient.
    #[cfg(feature = "always-encrypted")]
    fn assert_matches_dotnet(cek_hex: &str, cases: &[(SqlValue, &str)]) {
        let cek = unhex(cek_hex);
        let enc = AeadEncryptor::new(&cek).unwrap();
        for (value, reference) in cases {
            let norm = normalize_for_encryption(value).unwrap();
            let cipher = enc
                .encrypt(&norm, mssql_auth::EncryptionType::Deterministic)
                .unwrap();
            assert_eq!(
                cipher,
                unhex(reference),
                "ciphertext for {} must match Microsoft.Data.SqlClient",
                value.type_name()
            );
        }
    }

    #[cfg(feature = "always-encrypted")]
    #[test]
    fn ae_normalization_matches_dotnet() {
        use bytes::Bytes;
        assert_matches_dotnet(
            "B59D9F2C96784C232D53AB273D257DC79B7D2355BB82B1EC7054CE25E25F7B44",
            &[
                (
                    SqlValue::Int(42),
                    "01102FC5DEC5D3E463A8F4BDF512AA74E6AB953BA9A2F3F9A98CD18446B007DE5A6E2A1D1EB775035EA189CA5160A935CE093CAA9BB7E9233BB333AADEE86FDE1D",
                ),
                (
                    SqlValue::String("Ada".to_string()),
                    "01BFAC40E6DA541ACEFAD8ECF5598DB77B0C5349CFACBC3C9221C01B6037E593B78E8F398F620F837BD6A4A2B644125C4188DF278B94479B2218466D91107FE417",
                ),
                (
                    SqlValue::Binary(Bytes::from_static(&[0x01, 0x02, 0x03])),
                    "01ADE71457495F00FC9A16456F1B1EECB901D88DE97887025C189B1C4432E02071AB7594C48518CA5621E90165FAE337475B4CF3A3D00EF2D862FB0473713DF1E1",
                ),
            ],
        );
    }

    #[cfg(feature = "always-encrypted")]
    #[test]
    fn ae_normalization_rejects_unnormalizable_value() {
        assert!(normalize_for_encryption(&SqlValue::Null).is_err());
    }

    #[cfg(feature = "always-encrypted")]
    #[test]
    fn ae_normalization_matches_dotnet_numeric() {
        assert_matches_dotnet(
            "9590E42A8A6C8F13B5D09B8D5A128EF8B3A4A10301C7AF24AFC62ED0E02342F7",
            &[
                (
                    SqlValue::BigInt(0x0102030405060708),
                    "01E765FC4696660028BFD48FCAEAED81E0EB423CFF433CA97F1B2FF02F70744E7265C2AE73CAA562FFA98AF98CB1D3EF6A4649B3640359E1DB7D170C80E639DA68",
                ),
                (
                    SqlValue::SmallInt(258),
                    "012545AB817E1AEBDCEE1C00AEBFF3A013CAD20E0377BEFDD9186C263F8D1A909C313A753996F1B5E4A4AE17E901F6F781DCA707544766995D339601CA414063A0",
                ),
                (
                    SqlValue::TinyInt(200),
                    "01A97C33480277D16FFAEDA9068173D4173378542F2887EBCD31CDEEEB116BD59D48F9D459BDDCABAE469E891B4F82AA3D283440CA1B5E9FFC150F9D0AE54EC21E",
                ),
                (
                    SqlValue::Bool(true),
                    "01DDE18564051D630EE026331BCCAFC8F4122CC3919F81459F37D9C0E0C64A5317FCA08660FE5FC855917B97B72013F25B85ADD14ADDD7D5ED022EB1297FF29A7E",
                ),
                (
                    SqlValue::Float(3.5),
                    "017A452760E7BA7AA6A716F6707F55D9C3A81683C04A6B561B13AC1D8A848E93E239BB922EE3EE628B6D0081A590BB11747CC25D216240FB10171A0FA3B99A2DB3",
                ),
                (
                    SqlValue::Double(3.5),
                    "0171611557351FBC4561EBF0B9C98E0DC38AD2BD3E2C1D1E82F185D7E67D0425E506D11DD67BA3EB38F34FB01A8FCEF7E4B9A7256944334A521526613CFF6C8C5F",
                ),
            ],
        );
    }

    #[cfg(all(feature = "always-encrypted", feature = "uuid", feature = "chrono"))]
    #[test]
    fn ae_normalization_matches_dotnet_uuid_date() {
        assert_matches_dotnet(
            "9590E42A8A6C8F13B5D09B8D5A128EF8B3A4A10301C7AF24AFC62ED0E02342F7",
            &[
                (
                    SqlValue::Uuid(
                        uuid::Uuid::parse_str("01020304-0506-0708-090a-0b0c0d0e0f10").unwrap(),
                    ),
                    "01F58635AA18692D68BDF551ECDD7AC3A56682D3F91F111F8D8F36D5425C405A8F6AB3ED3C3666444478476BD65FF40DC83F6831F502826AFEEC3116F71A7A2020CCD254F4BA28FCDC0F96BA2E5264AE9E",
                ),
                (
                    SqlValue::Date(chrono::NaiveDate::from_ymd_opt(2024, 3, 15).unwrap()),
                    "0188B4F75A1F4BDA53C9CDDC1918C09CB57F68E13F5560F1F1D7168FE70707337B1156A97915B244F3C03D3E7352882A599511BD243471FD03683F371CF44E4B76",
                ),
            ],
        );
    }

    #[cfg(all(feature = "always-encrypted", feature = "decimal"))]
    #[test]
    fn ae_normalization_matches_dotnet_decimal_money() {
        let dec = rust_decimal::Decimal::new(123_456_789, 4);
        let money = rust_decimal::Decimal::new(123_400, 4);
        assert_matches_dotnet(
            "CBFB5AE21FB517C65DA0C6E8E11969C630798E473EF5827A70398012DF1D4B9E",
            &[
                (
                    SqlValue::Decimal(dec),
                    "018FAE46024B9B406C23600E6A9C694F9A9B39B785A995689EBE19437BA7E75768011A035A5B54B5E495512EBB46AE1146130940A0D0D834D61AA89B5AD9F71FFAF6EEEAE77E4856BA2AA5E016E2950A8D",
                ),
                (
                    SqlValue::Money(money),
                    "01B4CE4CAD8D6B241A1555C377A0ADD4C79424DD5162F710D116594F725C1BAB015169A0C7716076EEC90E013519B961DEF427BFC32462D9E45D166C791B73F793",
                ),
                (
                    SqlValue::SmallMoney(money),
                    "01B4CE4CAD8D6B241A1555C377A0ADD4C79424DD5162F710D116594F725C1BAB015169A0C7716076EEC90E013519B961DEF427BFC32462D9E45D166C791B73F793",
                ),
            ],
        );
    }

    #[test]
    fn test_encryption_config_defaults() {
        let config = EncryptionConfig::new();
        assert!(config.enabled);
        assert!(config.cache_ceks);
        // No providers registered yet, so the config is not ready.
        assert!(!config.is_ready());
    }

    #[test]
    fn test_result_set_encryption_info() {
        let cek_table = CekTable::new();
        let mut info = ResultSetEncryptionInfo::new(cek_table, 3);
        assert!(!info.is_column_encrypted(0));
        assert!(!info.is_column_encrypted(1));
        assert!(!info.is_column_encrypted(2));
        let metadata = CryptoMetadata {
            cek_table_ordinal: 0,
            base_user_type: 0,
            base_col_type: 0x26,
            base_type_info: tds_protocol::token::TypeInfo::default(),
            algorithm_id: 2,
            encryption_type: EncryptionTypeWire::Deterministic,
            normalization_version: 1,
        };
        info.set_column_crypto(1, metadata);
        assert!(!info.is_column_encrypted(0));
        assert!(info.is_column_encrypted(1));
        assert!(!info.is_column_encrypted(2));
        assert_eq!(
            info.get_encryption_type(1),
            Some(EncryptionTypeWire::Deterministic)
        );
    }

    #[test]
    fn test_parameter_encryption_info() {
        let mut info = ParameterEncryptionInfo::new();
        assert!(!info.needs_encryption("@p1"));
        let crypto = ParameterCryptoInfo::new(0, EncryptionTypeWire::Randomized, 2, 1);
        info.add_parameter("@p1".to_string(), crypto);
        assert!(info.needs_encryption("@p1"));
        assert!(!info.needs_encryption("@p2"));
        let param = info.get_parameter("@p1").unwrap();
        assert_eq!(param.encryption_type, EncryptionTypeWire::Randomized);
    }

    #[cfg(feature = "always-encrypted")]
    #[test]
    fn parse_describe_result_sets_groups_ceks_and_skips_plaintext() {
        use crate::row::{Column, Row};
        use crate::stream::ResultSet;
        use bytes::Bytes;

        /// Builds an in-memory result set with `n_cols` placeholder columns.
        fn rs(n_cols: usize, rows: Vec<Vec<SqlValue>>) -> ResultSet {
            let cols: Vec<Column> = (0..n_cols)
                .map(|i| Column::new(format!("c{i}"), i, "x"))
                .collect();
            let rows = rows
                .into_iter()
                .map(|vals| Row::from_values(cols.clone(), vals))
                .collect();
            ResultSet::new(cols, rows)
        }

        let mdv1 = Bytes::from_static(&[1, 0, 0, 0, 0, 0, 0, 0]);
        let mdv2 = Bytes::from_static(&[255, 0, 0, 0, 0, 0, 0, 0]);
        let rs1 = rs(
            9,
            vec![
                vec![
                    SqlValue::Int(1),
                    SqlValue::Int(7),
                    SqlValue::Int(56),
                    SqlValue::Int(1),
                    SqlValue::Binary(mdv1.clone()),
                    SqlValue::Binary(Bytes::from_static(b"env-a")),
                    SqlValue::String("IN_MEMORY_KEY_STORE".into()),
                    SqlValue::String("path-a".into()),
                    SqlValue::String("RSA_OAEP".into()),
                ],
                vec![
                    SqlValue::Int(1),
                    SqlValue::Int(7),
                    SqlValue::Int(56),
                    SqlValue::Int(1),
                    SqlValue::Binary(mdv1),
                    SqlValue::Binary(Bytes::from_static(b"env-a2")),
                    SqlValue::String("PROV_2".into()),
                    SqlValue::String("path-a2".into()),
                    SqlValue::String("RSA_OAEP".into()),
                ],
                vec![
                    SqlValue::Int(2),
                    SqlValue::Int(7),
                    SqlValue::Int(57),
                    SqlValue::Int(1),
                    SqlValue::Binary(mdv2),
                    SqlValue::Binary(Bytes::from_static(b"env-b")),
                    SqlValue::String("IN_MEMORY_KEY_STORE".into()),
                    SqlValue::String("path-b".into()),
                    SqlValue::String("RSA_OAEP".into()),
                ],
            ],
        );
        let rs2 = rs(
            6,
            vec![
                vec![
                    SqlValue::Int(1),
                    SqlValue::String("@det".into()),
                    SqlValue::TinyInt(2),
                    SqlValue::TinyInt(1),
                    SqlValue::Int(1),
                    SqlValue::TinyInt(1),
                ],
                vec![
                    SqlValue::Int(2),
                    SqlValue::String("@rand".into()),
                    SqlValue::TinyInt(2),
                    SqlValue::TinyInt(2),
                    SqlValue::Int(2),
                    SqlValue::TinyInt(1),
                ],
                vec![
                    SqlValue::Int(3),
                    SqlValue::String("@plain".into()),
                    SqlValue::TinyInt(0),
                    SqlValue::TinyInt(0),
                    SqlValue::Int(0),
                    SqlValue::TinyInt(0),
                ],
            ],
        );
        let mut sets = vec![rs1, rs2];
        let info = ParameterEncryptionInfo::from_describe_result_sets(&mut sets).unwrap();
        assert_eq!(info.cek_table.len(), 2);
        let e0 = info.cek_table.get(0).unwrap();
        assert_eq!(e0.cek_id, 56);
        assert_eq!(e0.cek_md_version, 1);
        assert_eq!(e0.values.len(), 2, "two CMK-wrappings group under one CEK");
        assert_eq!(e0.values[0].key_store_provider_name, "IN_MEMORY_KEY_STORE");
        assert_eq!(e0.values[1].key_store_provider_name, "PROV_2");
        let e1 = info.cek_table.get(1).unwrap();
        assert_eq!(e1.cek_id, 57);
        assert_eq!(e1.cek_md_version, 255);
        let det = info.get_parameter("@det").unwrap();
        assert_eq!(det.encryption_type, EncryptionTypeWire::Deterministic);
        assert_eq!(det.algorithm_id, 2);
        assert_eq!(det.normalization_rule_version, 1);
        assert_eq!(det.cek_ordinal, 0, "server ordinal 1 -> positional index 0");
        let rand = info.get_parameter("@rand").unwrap();
        assert_eq!(rand.encryption_type, EncryptionTypeWire::Randomized);
        assert_eq!(
            rand.cek_ordinal, 1,
            "server ordinal 2 -> positional index 1"
        );
        assert!(!info.needs_encryption("@plain"));
        assert_eq!(info.parameters.len(), 2);
    }

    #[cfg(feature = "always-encrypted")]
    #[test]
    fn parse_describe_result_sets_rejects_missing_result_set() {
        use crate::row::{Column, Row};
        use crate::stream::ResultSet;
        let cols: Vec<Column> = (0..9)
            .map(|i| Column::new(format!("c{i}"), i, "x"))
            .collect();
        let mut sets = vec![ResultSet::new(cols, Vec::<Row>::new())];
        assert!(ParameterEncryptionInfo::from_describe_result_sets(&mut sets).is_err());
    }
}