use std::collections::HashMap;
use mssql_auth::KeyStoreProvider;
use tds_protocol::crypto::{CekTable, CekTableEntry, CryptoMetadata, EncryptionTypeWire};
#[cfg(feature = "always-encrypted")]
use mssql_auth::{AeadEncryptor, CekCache, CekCacheKey, EncryptionError};
#[cfg(feature = "always-encrypted")]
use std::sync::Arc;
/// Client-side configuration for Always Encrypted.
///
/// It records three things:
/// - whether the feature is switched on;
/// - which key store providers are available to decrypt column encryption keys (CEKs);
/// - whether decrypted CEKs should be cached.
///
/// The derived [`Default`] yields `enabled: false`, `cache_ceks: false` and no
/// providers. [`EncryptionConfig::new`] instead turns both flags on.
/// NOTE(review): confirm this difference between `default()` and `new()` is intentional.
#[derive(Default)]
pub struct EncryptionConfig {
    /// Master switch; `is_ready` only reports `true` when this is set.
    pub enabled: bool,
    /// Registered key store providers, searched in registration order by
    /// `get_provider`.
    providers: Vec<Box<dyn KeyStoreProvider>>,
    /// Whether an `EncryptionContext` built from this config caches encryptors
    /// derived from decrypted CEKs.
    pub cache_ceks: bool,
}
impl EncryptionConfig {
    /// Creates a configuration with Always Encrypted enabled and CEK caching on.
    ///
    /// No key store providers are registered. [`is_ready`](Self::is_ready) stays
    /// `false` until at least one is added via
    /// [`register_provider`](Self::register_provider) or
    /// [`with_provider`](Self::with_provider).
    #[must_use]
    pub fn new() -> Self {
        Self {
            enabled: true,
            cache_ceks: true,
            ..Self::default()
        }
    }

    /// Adds a key store provider.
    ///
    /// Lookups in [`get_provider`](Self::get_provider) walk providers in
    /// registration order. If two providers report the same name, the one
    /// registered first is the one returned.
    pub fn register_provider(&mut self, provider: impl KeyStoreProvider + 'static) {
        let boxed: Box<dyn KeyStoreProvider> = Box::new(provider);
        self.providers.push(boxed);
    }

    /// Builder-style form of [`register_provider`](Self::register_provider).
    #[must_use]
    pub fn with_provider(mut self, provider: impl KeyStoreProvider + 'static) -> Self {
        self.register_provider(provider);
        self
    }

    /// Builder-style setter for [`cache_ceks`](Self::cache_ceks).
    #[must_use]
    pub fn with_cek_caching(self, enabled: bool) -> Self {
        Self {
            cache_ceks: enabled,
            ..self
        }
    }

    /// Returns the first registered provider whose `provider_name()` equals `name`.
    ///
    /// Returns `None` when no registered provider has that name.
    pub fn get_provider(&self, name: &str) -> Option<&dyn KeyStoreProvider> {
        for provider in &self.providers {
            if provider.provider_name() == name {
                return Some(provider.as_ref());
            }
        }
        None
    }

    /// `true` when encryption is enabled and at least one provider is registered.
    #[must_use]
    pub fn is_ready(&self) -> bool {
        if !self.enabled {
            return false;
        }
        !self.providers.is_empty()
    }
}
impl std::fmt::Debug for EncryptionConfig {
    /// Prints the two flags and the number of registered providers.
    ///
    /// The provider objects themselves are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("EncryptionConfig");
        builder.field("enabled", &self.enabled);
        builder.field("provider_count", &self.providers.len());
        builder.field("cache_ceks", &self.cache_ceks);
        builder.finish()
    }
}
/// Runtime state used to encrypt and decrypt Always Encrypted values.
///
/// It holds the shared [`EncryptionConfig`], which is used for provider lookup,
/// plus a cache of encryptors built from decrypted column encryption keys.
#[cfg(feature = "always-encrypted")]
pub struct EncryptionContext {
    /// Shared configuration; key store providers are looked up here when a CEK
    /// has to be decrypted.
    config: std::sync::Arc<EncryptionConfig>,
    /// Encryptors keyed by `CekCacheKey` (database id, CEK id, CEK version).
    cek_cache: CekCache,
    /// Copy of `config.cache_ceks` taken when the context was constructed.
    cache_enabled: bool,
}
#[cfg(feature = "always-encrypted")]
impl EncryptionContext {
    /// Creates a context around an already shared configuration.
    ///
    /// The CEK cache starts empty. Whether it is consulted is decided once, here,
    /// from `config.cache_ceks`.
    pub fn from_arc(config: std::sync::Arc<EncryptionConfig>) -> Self {
        let cache_enabled = config.cache_ceks;
        Self {
            config,
            cek_cache: CekCache::new(),
            cache_enabled,
        }
    }

    /// Creates a context that takes ownership of `config`.
    pub fn new(config: EncryptionConfig) -> Self {
        Self::from_arc(std::sync::Arc::new(config))
    }

    /// Returns an encryptor for the column encryption key described by `cek_entry`.
    ///
    /// When caching is enabled, an encryptor already cached under
    /// `(database_id, cek_id, cek_version)` is returned directly. Otherwise:
    /// 1. The entry's primary encrypted CEK value is decrypted by the key store
    ///    provider it names.
    /// 2. The decrypted key is either inserted into the cache (caching enabled)
    ///    or turned into a fresh, uncached encryptor.
    ///
    /// Misses are not coalesced: concurrent callers that miss on the same key may
    /// each ask the provider to decrypt the CEK.
    ///
    /// NOTE(review): only `primary_value()` is tried. If an entry can carry several
    /// encrypted values (e.g. during key rotation), falling back to the others may
    /// be desirable.
    ///
    /// # Errors
    ///
    /// - [`EncryptionError::CekDecryptionFailed`] if the entry has no CEK value.
    /// - [`EncryptionError::KeyStoreNotFound`] if the named provider is not registered.
    /// - Any error returned by the provider while decrypting the CEK, or while
    ///   building the encryptor from the decrypted key.
    pub async fn get_encryptor(
        &self,
        cek_entry: &CekTableEntry,
    ) -> Result<Arc<AeadEncryptor>, EncryptionError> {
        let cache_key = CekCacheKey::new(
            cek_entry.database_id,
            cek_entry.cek_id,
            cek_entry.cek_version,
        );
        if self.cache_enabled {
            if let Some(encryptor) = self.cek_cache.get(&cache_key) {
                return Ok(encryptor);
            }
        }
        let cek_value = cek_entry
            .primary_value()
            .ok_or_else(|| EncryptionError::CekDecryptionFailed("No CEK value available".into()))?;
        let provider = self
            .config
            .get_provider(&cek_value.key_store_provider_name)
            .ok_or_else(|| {
                EncryptionError::KeyStoreNotFound(cek_value.key_store_provider_name.clone())
            })?;
        let decrypted_cek = provider
            .decrypt_cek(
                &cek_value.cmk_path,
                &cek_value.encryption_algorithm,
                &cek_value.encrypted_value,
            )
            .await?;
        if self.cache_enabled {
            self.cek_cache.insert(cache_key, decrypted_cek)
        } else {
            Ok(Arc::new(AeadEncryptor::new(&decrypted_cek)?))
        }
    }

    /// Encrypts `plaintext` with the CEK from `cek_entry` using `encryption_type`.
    ///
    /// # Errors
    ///
    /// - [`EncryptionError::UnsupportedOperation`] if `encryption_type` is neither
    ///   deterministic nor randomized. This is checked before any key material is
    ///   resolved.
    /// - Any error from [`get_encryptor`](Self::get_encryptor) or from the
    ///   encryption itself.
    pub async fn encrypt_value(
        &self,
        plaintext: &[u8],
        cek_entry: &CekTableEntry,
        encryption_type: EncryptionTypeWire,
    ) -> Result<Vec<u8>, EncryptionError> {
        // Reject unsupported types up front. Resolving the encryptor may await a
        // key store provider's `decrypt_cek`, and that work is wasted if the
        // request is going to be refused anyway.
        let enc_type = match encryption_type {
            EncryptionTypeWire::Deterministic => mssql_auth::EncryptionType::Deterministic,
            EncryptionTypeWire::Randomized => mssql_auth::EncryptionType::Randomized,
            _ => {
                return Err(EncryptionError::UnsupportedOperation(format!(
                    "unsupported encryption type: {encryption_type:?}"
                )));
            }
        };
        let encryptor = self.get_encryptor(cek_entry).await?;
        encryptor.encrypt(plaintext, enc_type)
    }

    /// Decrypts `ciphertext` with the CEK from `cek_entry`.
    ///
    /// # Errors
    ///
    /// Any error from [`get_encryptor`](Self::get_encryptor) or from the
    /// decryption itself.
    pub async fn decrypt_value(
        &self,
        ciphertext: &[u8],
        cek_entry: &CekTableEntry,
    ) -> Result<Vec<u8>, EncryptionError> {
        let encryptor = self.get_encryptor(cek_entry).await?;
        encryptor.decrypt(ciphertext)
    }

    /// Drops every cached encryptor.
    pub fn clear_cache(&self) {
        self.cek_cache.clear();
    }

    /// `true` if the configuration has a provider registered under `name`.
    pub fn has_provider(&self, name: &str) -> bool {
        self.config.get_provider(name).is_some()
    }
}
#[cfg(feature = "always-encrypted")]
impl std::fmt::Debug for EncryptionContext {
    /// Prints summary counts and the caching flag.
    ///
    /// Neither the providers nor the cached encryptors are printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let provider_count = self.config.providers.len();
        let cache_entries = self.cek_cache.len();
        f.debug_struct("EncryptionContext")
            .field("provider_count", &provider_count)
            .field("cache_entries", &cache_entries)
            .field("cache_enabled", &self.cache_enabled)
            .finish()
    }
}
/// Always Encrypted metadata for one result set: its CEK table and the
/// per-column crypto metadata.
#[derive(Debug, Clone)]
pub struct ResultSetEncryptionInfo {
    /// Column encryption keys; encrypted columns refer into this table via
    /// `CryptoMetadata::cek_table_ordinal`.
    pub cek_table: CekTable,
    /// Crypto metadata indexed by zero-based column ordinal.
    ///
    /// `None` marks a column that is not encrypted.
    pub column_crypto: Vec<Option<CryptoMetadata>>,
}
impl ResultSetEncryptionInfo {
    /// Creates info for a result set of `column_count` columns, none of which is
    /// marked as encrypted yet.
    pub fn new(cek_table: CekTable, column_count: usize) -> Self {
        let column_crypto = vec![None; column_count];
        Self {
            cek_table,
            column_crypto,
        }
    }

    /// Marks the column at `ordinal` as encrypted with `metadata`.
    ///
    /// An `ordinal` past the last column is ignored and `metadata` is dropped.
    pub fn set_column_crypto(&mut self, ordinal: usize, metadata: CryptoMetadata) {
        if let Some(slot) = self.column_crypto.get_mut(ordinal) {
            *slot = Some(metadata);
        }
    }

    /// Returns the CEK entry used by the column at `ordinal`.
    ///
    /// Returns `None` if any of these holds:
    /// - the ordinal is out of range;
    /// - the column is not encrypted;
    /// - `cek_table.get` finds no entry for the column's CEK ordinal.
    pub fn get_cek_for_column(&self, ordinal: usize) -> Option<&CekTableEntry> {
        match self.column_crypto.get(ordinal) {
            Some(Some(crypto)) => self.cek_table.get(crypto.cek_table_ordinal),
            _ => None,
        }
    }

    /// `true` only when `ordinal` is in range and that column has crypto metadata.
    pub fn is_column_encrypted(&self, ordinal: usize) -> bool {
        matches!(self.column_crypto.get(ordinal), Some(Some(_)))
    }

    /// Returns the encryption type of the column at `ordinal`.
    ///
    /// Returns `None` for out-of-range or unencrypted columns.
    pub fn get_encryption_type(&self, ordinal: usize) -> Option<EncryptionTypeWire> {
        let crypto = self.column_crypto.get(ordinal)?.as_ref()?;
        Some(crypto.encryption_type)
    }
}
/// Encryption requirements for the parameters of one statement.
#[derive(Debug, Clone)]
pub struct ParameterEncryptionInfo {
    /// Column encryption keys referenced by the encrypted parameters.
    /// NOTE(review): presumably indexed by `ParameterCryptoInfo::cek_ordinal`; confirm.
    pub cek_table: CekTable,
    /// Crypto settings keyed by parameter name (e.g. `"@p1"`).
    ///
    /// Lookups are exact `HashMap` key matches, so they are case-sensitive.
    pub parameters: HashMap<String, ParameterCryptoInfo>,
}
impl ParameterEncryptionInfo {
    /// Creates an empty description: no CEKs and no parameters needing encryption.
    pub fn new() -> Self {
        let cek_table = CekTable::new();
        let parameters = HashMap::new();
        Self {
            cek_table,
            parameters,
        }
    }

    /// Records that parameter `name` must be encrypted according to `info`.
    ///
    /// Any entry previously stored under the same name is replaced.
    pub fn add_parameter(&mut self, name: String, info: ParameterCryptoInfo) {
        self.parameters.insert(name, info);
    }

    /// Returns the crypto settings for parameter `name`, if it has any.
    pub fn get_parameter(&self, name: &str) -> Option<&ParameterCryptoInfo> {
        self.parameters.get(name)
    }

    /// `true` if parameter `name` has been registered for encryption.
    pub fn needs_encryption(&self, name: &str) -> bool {
        self.get_parameter(name).is_some()
    }
}
impl Default for ParameterEncryptionInfo {
fn default() -> Self {
Self::new()
}
}
/// Encryption settings for a single statement parameter.
#[derive(Debug, Clone)]
pub struct ParameterCryptoInfo {
    /// Ordinal of the CEK used for this parameter.
    /// NOTE(review): presumably an index into the owning
    /// `ParameterEncryptionInfo::cek_table`; confirm.
    pub cek_ordinal: u16,
    /// Encryption type to apply, in its wire representation.
    pub encryption_type: EncryptionTypeWire,
    /// Encryption algorithm identifier.
    /// NOTE(review): presumably the MS-TDS algorithm id (tests use 2); confirm.
    pub algorithm_id: u8,
    /// Ordinal of the target column.
    /// NOTE(review): presumably the column whose encryption settings this
    /// parameter follows; confirm against the code that fills this in.
    pub column_ordinal: u16,
    /// Database identifier.
    /// NOTE(review): presumably the database the CEK belongs to; confirm.
    pub database_id: u32,
}
impl ParameterCryptoInfo {
    /// Bundles the encryption settings for one parameter.
    ///
    /// Each argument is stored unchanged in the field of the same name.
    pub fn new(
        cek_ordinal: u16,
        encryption_type: EncryptionTypeWire,
        algorithm_id: u8,
        column_ordinal: u16,
        database_id: u32,
    ) -> Self {
        Self {
            cek_ordinal,
            encryption_type,
            algorithm_id,
            column_ordinal,
            database_id,
        }
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Deterministic-encryption metadata pointing at CEK ordinal 0, shared by
    /// the result-set tests.
    fn sample_metadata() -> CryptoMetadata {
        CryptoMetadata {
            cek_table_ordinal: 0,
            base_user_type: 0,
            base_col_type: 0x26,
            base_type_info: tds_protocol::token::TypeInfo::default(),
            algorithm_id: 2,
            encryption_type: EncryptionTypeWire::Deterministic,
            normalization_version: 1,
        }
    }

    #[test]
    fn test_encryption_config_defaults() {
        let config = EncryptionConfig::new();
        assert!(config.enabled);
        assert!(config.cache_ceks);
        // No providers are registered yet, so the config is not usable.
        assert!(!config.is_ready());
    }

    #[test]
    fn test_encryption_config_derived_default_is_disabled() {
        // `Default` is derived, so unlike `new()` both flags start out false.
        let config = EncryptionConfig::default();
        assert!(!config.enabled);
        assert!(!config.cache_ceks);
        assert!(!config.is_ready());
    }

    #[test]
    fn test_encryption_config_builders_and_debug() {
        let config = EncryptionConfig::new().with_cek_caching(false);
        assert!(!config.cache_ceks);
        assert!(config.get_provider("AZURE_KEY_VAULT").is_none());
        assert_eq!(
            format!("{config:?}"),
            "EncryptionConfig { enabled: true, provider_count: 0, cache_ceks: false }"
        );
    }

    #[test]
    fn test_result_set_encryption_info() {
        let cek_table = CekTable::new();
        let mut info = ResultSetEncryptionInfo::new(cek_table, 3);
        assert!(!info.is_column_encrypted(0));
        assert!(!info.is_column_encrypted(1));
        assert!(!info.is_column_encrypted(2));

        info.set_column_crypto(1, sample_metadata());
        assert!(!info.is_column_encrypted(0));
        assert!(info.is_column_encrypted(1));
        assert!(!info.is_column_encrypted(2));
        assert_eq!(
            info.get_encryption_type(1),
            Some(EncryptionTypeWire::Deterministic)
        );
        assert_eq!(info.get_encryption_type(0), None);
        // The metadata references CEK ordinal 0, but the CEK table is empty.
        assert!(info.get_cek_for_column(1).is_none());
    }

    #[test]
    fn test_result_set_encryption_info_out_of_range() {
        let mut info = ResultSetEncryptionInfo::new(CekTable::new(), 2);
        // Ordinals past the column count are ignored rather than panicking.
        info.set_column_crypto(5, sample_metadata());
        assert_eq!(info.column_crypto.len(), 2);
        assert!(info.column_crypto.iter().all(Option::is_none));
        assert!(!info.is_column_encrypted(5));
        assert_eq!(info.get_encryption_type(5), None);
        assert!(info.get_cek_for_column(5).is_none());
    }

    #[test]
    fn test_parameter_encryption_info() {
        let mut info = ParameterEncryptionInfo::new();
        assert!(!info.needs_encryption("@p1"));

        let crypto = ParameterCryptoInfo::new(0, EncryptionTypeWire::Randomized, 2, 1, 1);
        info.add_parameter("@p1".to_string(), crypto);
        assert!(info.needs_encryption("@p1"));
        assert!(!info.needs_encryption("@p2"));
        let param = info.get_parameter("@p1").unwrap();
        assert_eq!(param.encryption_type, EncryptionTypeWire::Randomized);

        // Re-adding a name replaces the previous entry.
        info.add_parameter(
            "@p1".to_string(),
            ParameterCryptoInfo::new(0, EncryptionTypeWire::Deterministic, 2, 1, 1),
        );
        assert_eq!(info.parameters.len(), 1);
        assert_eq!(
            info.get_parameter("@p1").unwrap().encryption_type,
            EncryptionTypeWire::Deterministic
        );
    }

    #[test]
    fn test_parameter_encryption_info_default_is_empty() {
        let info = ParameterEncryptionInfo::default();
        assert!(info.parameters.is_empty());
        assert!(info.get_parameter("@p1").is_none());
    }
}