1use std::collections::HashMap;
48
49use mssql_auth::KeyStoreProvider;
50use tds_protocol::crypto::{CekTable, CekTableEntry, CryptoMetadata, EncryptionTypeWire};
51
52#[cfg(feature = "always-encrypted")]
53use mssql_auth::{AeadEncryptor, CekCache, CekCacheKey, EncryptionError};
54#[cfg(feature = "always-encrypted")]
55use std::sync::Arc;
56
/// Configuration for Always Encrypted support.
///
/// Holds the registered key store providers used to decrypt column encryption
/// keys (CEKs), plus two feature flags.
///
/// Note that the derived [`Default`] produces a *disabled* configuration
/// (`enabled == false`, `cache_ceks == false`, no providers), whereas
/// [`EncryptionConfig::new`] sets both flags to `true`.
#[derive(Default)]
pub struct EncryptionConfig {
    /// Whether Always Encrypted is turned on for this configuration.
    pub enabled: bool,
    /// Registered key store providers. Lookups by name scan this list in
    /// registration order, so the first provider with a matching name wins.
    providers: Vec<Box<dyn KeyStoreProvider>>,
    /// Whether decrypted column encryption keys should be cached by an
    /// `EncryptionContext` built from this configuration.
    pub cache_ceks: bool,
}
67
68impl EncryptionConfig {
69 #[must_use]
71 pub fn new() -> Self {
72 Self {
73 enabled: true,
74 providers: Vec::new(),
75 cache_ceks: true,
76 }
77 }
78
79 pub fn register_provider(&mut self, provider: impl KeyStoreProvider + 'static) {
81 self.providers.push(Box::new(provider));
82 }
83
84 #[must_use]
86 pub fn with_provider(mut self, provider: impl KeyStoreProvider + 'static) -> Self {
87 self.register_provider(provider);
88 self
89 }
90
91 #[must_use]
93 pub fn with_cek_caching(mut self, enabled: bool) -> Self {
94 self.cache_ceks = enabled;
95 self
96 }
97
98 pub fn get_provider(&self, name: &str) -> Option<&dyn KeyStoreProvider> {
100 self.providers
101 .iter()
102 .find(|p| p.provider_name() == name)
103 .map(|p| p.as_ref())
104 }
105
106 #[must_use]
108 pub fn is_ready(&self) -> bool {
109 self.enabled && !self.providers.is_empty()
110 }
111}
112
113impl std::fmt::Debug for EncryptionConfig {
114 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115 f.debug_struct("EncryptionConfig")
116 .field("enabled", &self.enabled)
117 .field("provider_count", &self.providers.len())
118 .field("cache_ceks", &self.cache_ceks)
119 .finish()
120 }
121}
122
#[cfg(feature = "always-encrypted")]
/// Runtime state for encrypting and decrypting Always Encrypted values.
///
/// Pairs a shared [`EncryptionConfig`] (used to find key store providers) with
/// a cache of per-CEK encryptors.
pub struct EncryptionContext {
    /// Shared configuration that supplies the key store providers.
    config: std::sync::Arc<EncryptionConfig>,
    /// Cache of encryptors keyed by (database id, CEK id, CEK version).
    cek_cache: CekCache,
    /// Copy of `config.cache_ceks`, taken when the context is constructed.
    cache_enabled: bool,
}
142
#[cfg(feature = "always-encrypted")]
impl EncryptionContext {
    /// Creates a context that shares an existing configuration.
    ///
    /// The caching flag is copied from `config.cache_ceks` at construction time,
    /// and the context starts with an empty CEK cache.
    pub fn from_arc(config: Arc<EncryptionConfig>) -> Self {
        let cache_enabled = config.cache_ceks;
        Self {
            config,
            cek_cache: CekCache::new(),
            cache_enabled,
        }
    }

    /// Creates a context that takes ownership of `config`.
    pub fn new(config: EncryptionConfig) -> Self {
        Self::from_arc(Arc::new(config))
    }

    /// Returns an encryptor for the column encryption key described by `cek_entry`.
    ///
    /// If caching is enabled and an encryptor for the same
    /// `(database_id, cek_id, cek_version)` is cached, that encryptor is returned.
    /// Otherwise the entry's primary encrypted value is decrypted through the key
    /// store provider it names. The result is then inserted into the cache when
    /// caching is enabled, or wrapped in a fresh [`AeadEncryptor`] when it is not.
    ///
    /// # Errors
    ///
    /// - [`EncryptionError::CekDecryptionFailed`] if the entry has no primary value.
    /// - [`EncryptionError::KeyStoreNotFound`] if no registered provider matches
    ///   the value's `key_store_provider_name`.
    /// - Any error propagated from the provider's `decrypt_cek`, from the cache
    ///   insert, or from [`AeadEncryptor::new`].
    pub async fn get_encryptor(
        &self,
        cek_entry: &CekTableEntry,
    ) -> Result<Arc<AeadEncryptor>, EncryptionError> {
        let cache_key = CekCacheKey::new(
            cek_entry.database_id,
            cek_entry.cek_id,
            cek_entry.cek_version,
        );

        if self.cache_enabled {
            if let Some(encryptor) = self.cek_cache.get(&cache_key) {
                return Ok(encryptor);
            }
        }

        // NOTE(review): concurrent misses for the same key can each reach the
        // provider below; which result ends up cached depends on `CekCache::insert`.
        let cek_value = cek_entry
            .primary_value()
            .ok_or_else(|| EncryptionError::CekDecryptionFailed("No CEK value available".into()))?;

        let provider = self
            .config
            .get_provider(&cek_value.key_store_provider_name)
            .ok_or_else(|| {
                EncryptionError::KeyStoreNotFound(cek_value.key_store_provider_name.clone())
            })?;

        let decrypted_cek = provider
            .decrypt_cek(
                &cek_value.cmk_path,
                &cek_value.encryption_algorithm,
                &cek_value.encrypted_value,
            )
            .await?;

        if self.cache_enabled {
            self.cek_cache.insert(cache_key, decrypted_cek)
        } else {
            Ok(Arc::new(AeadEncryptor::new(&decrypted_cek)?))
        }
    }

    /// Encrypts `plaintext` with the CEK described by `cek_entry`.
    ///
    /// Only deterministic and randomized encryption are supported.
    ///
    /// # Errors
    ///
    /// - [`EncryptionError::UnsupportedOperation`] for any other `encryption_type`.
    ///   This check now happens *before* the key is resolved.
    /// - Any error from [`EncryptionContext::get_encryptor`] or from the
    ///   encryptor's `encrypt`.
    pub async fn encrypt_value(
        &self,
        plaintext: &[u8],
        cek_entry: &CekTableEntry,
        encryption_type: EncryptionTypeWire,
    ) -> Result<Vec<u8>, EncryptionError> {
        // Validate the encryption type first. Resolving an uncached key awaits the
        // provider's `decrypt_cek` and may populate the cache, which is wasted work
        // for a request that is rejected anyway.
        let enc_type = match encryption_type {
            EncryptionTypeWire::Deterministic => mssql_auth::EncryptionType::Deterministic,
            EncryptionTypeWire::Randomized => mssql_auth::EncryptionType::Randomized,
            _ => {
                return Err(EncryptionError::UnsupportedOperation(format!(
                    "unsupported encryption type: {encryption_type:?}"
                )));
            }
        };

        let encryptor = self.get_encryptor(cek_entry).await?;
        encryptor.encrypt(plaintext, enc_type)
    }

    /// Decrypts `ciphertext` with the CEK described by `cek_entry`.
    ///
    /// # Errors
    ///
    /// Any error from [`EncryptionContext::get_encryptor`] or from the
    /// encryptor's `decrypt`.
    pub async fn decrypt_value(
        &self,
        ciphertext: &[u8],
        cek_entry: &CekTableEntry,
    ) -> Result<Vec<u8>, EncryptionError> {
        let encryptor = self.get_encryptor(cek_entry).await?;
        encryptor.decrypt(ciphertext)
    }

    /// Clears the encryptor cache (delegates to `CekCache::clear`).
    pub fn clear_cache(&self) {
        self.cek_cache.clear();
    }

    /// Returns `true` if a key store provider named `name` is registered in the
    /// underlying configuration.
    pub fn has_provider(&self, name: &str) -> bool {
        self.config.get_provider(name).is_some()
    }
}
273
#[cfg(feature = "always-encrypted")]
impl std::fmt::Debug for EncryptionContext {
    /// Formats a summary of the context: provider count, number of cached
    /// entries, and whether caching is enabled. No key material is printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let provider_count = self.config.providers.len();
        let cache_entries = self.cek_cache.len();
        let mut out = f.debug_struct("EncryptionContext");
        out.field("provider_count", &provider_count);
        out.field("cache_entries", &cache_entries);
        out.field("cache_enabled", &self.cache_enabled);
        out.finish()
    }
}
284
/// Encryption metadata for the columns of a result set.
#[derive(Debug, Clone)]
pub struct ResultSetEncryptionInfo {
    /// Column encryption keys referenced by the encrypted columns through
    /// `CryptoMetadata::cek_table_ordinal`.
    pub cek_table: CekTable,
    /// Per-column crypto metadata, indexed by zero-based column ordinal.
    /// `None` means the column is not encrypted.
    pub column_crypto: Vec<Option<CryptoMetadata>>,
}
296
297impl ResultSetEncryptionInfo {
298 pub fn new(cek_table: CekTable, column_count: usize) -> Self {
300 Self {
301 cek_table,
302 column_crypto: vec![None; column_count],
303 }
304 }
305
306 pub fn set_column_crypto(&mut self, ordinal: usize, metadata: CryptoMetadata) {
308 if ordinal < self.column_crypto.len() {
309 self.column_crypto[ordinal] = Some(metadata);
310 }
311 }
312
313 pub fn get_cek_for_column(&self, ordinal: usize) -> Option<&CekTableEntry> {
315 let crypto = self.column_crypto.get(ordinal)?.as_ref()?;
316 self.cek_table.get(crypto.cek_table_ordinal)
317 }
318
319 pub fn is_column_encrypted(&self, ordinal: usize) -> bool {
321 self.column_crypto
322 .get(ordinal)
323 .map(|c| c.is_some())
324 .unwrap_or(false)
325 }
326
327 pub fn get_encryption_type(&self, ordinal: usize) -> Option<EncryptionTypeWire> {
329 self.column_crypto
330 .get(ordinal)?
331 .as_ref()
332 .map(|c| c.encryption_type)
333 }
334}
335
/// Encryption metadata for statement parameters.
#[derive(Debug, Clone)]
pub struct ParameterEncryptionInfo {
    /// Column encryption keys used by the parameters (presumably indexed by
    /// `ParameterCryptoInfo::cek_ordinal` — confirm against the producer).
    pub cek_table: CekTable,
    /// Per-parameter crypto metadata, keyed by parameter name exactly as
    /// inserted (e.g. `"@p1"`). Lookups are plain, case-sensitive `HashMap` lookups.
    pub parameters: HashMap<String, ParameterCryptoInfo>,
}
347
348impl ParameterEncryptionInfo {
349 pub fn new() -> Self {
351 Self {
352 cek_table: CekTable::new(),
353 parameters: HashMap::new(),
354 }
355 }
356
357 pub fn add_parameter(&mut self, name: String, info: ParameterCryptoInfo) {
359 self.parameters.insert(name, info);
360 }
361
362 pub fn get_parameter(&self, name: &str) -> Option<&ParameterCryptoInfo> {
364 self.parameters.get(name)
365 }
366
367 pub fn needs_encryption(&self, name: &str) -> bool {
369 self.parameters.contains_key(name)
370 }
371}
372
373impl Default for ParameterEncryptionInfo {
374 fn default() -> Self {
375 Self::new()
376 }
377}
378
/// Encryption metadata for a single statement parameter.
#[derive(Debug, Clone)]
pub struct ParameterCryptoInfo {
    /// Ordinal of the parameter's column encryption key (presumably an index
    /// into the owning `ParameterEncryptionInfo::cek_table` — confirm).
    pub cek_ordinal: u16,
    /// Wire-level encryption type (e.g. deterministic or randomized).
    pub encryption_type: EncryptionTypeWire,
    /// Encryption algorithm identifier as received from the server
    /// (NOTE(review): value meanings not established in this file).
    pub algorithm_id: u8,
    /// Ordinal of the target column (assumed to be the column the parameter
    /// is compared with or written to — confirm against the producer).
    pub column_ordinal: u16,
    /// Identifier of the database that owns the CEK (assumption — verify).
    pub database_id: u32,
}
393
394impl ParameterCryptoInfo {
395 pub fn new(
397 cek_ordinal: u16,
398 encryption_type: EncryptionTypeWire,
399 algorithm_id: u8,
400 column_ordinal: u16,
401 database_id: u32,
402 ) -> Self {
403 Self {
404 cek_ordinal,
405 encryption_type,
406 algorithm_id,
407 column_ordinal,
408 database_id,
409 }
410 }
411}
412
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)] // unwrap/expect are acceptable in tests
mod tests {
    use super::*;

    /// Deterministic crypto metadata that points at CEK ordinal 0.
    fn sample_metadata() -> CryptoMetadata {
        CryptoMetadata {
            cek_table_ordinal: 0,
            base_user_type: 0,
            base_col_type: 0x26,
            base_type_info: tds_protocol::token::TypeInfo::default(),
            algorithm_id: 2,
            encryption_type: EncryptionTypeWire::Deterministic,
            normalization_version: 1,
        }
    }

    #[test]
    fn test_encryption_config_defaults() {
        let config = EncryptionConfig::new();
        assert!(config.enabled);
        assert!(config.cache_ceks);
        // No providers are registered yet, so the config is not usable.
        assert!(!config.is_ready());
    }

    #[test]
    fn test_encryption_config_derived_default_is_disabled() {
        // The derived Default differs from `new()`: everything is off.
        let config = EncryptionConfig::default();
        assert!(!config.enabled);
        assert!(!config.cache_ceks);
        assert!(!config.is_ready());
        assert!(config.get_provider("ANY_PROVIDER").is_none());
    }

    #[test]
    fn test_with_cek_caching_toggles_flag_only() {
        let config = EncryptionConfig::new().with_cek_caching(false);
        assert!(!config.cache_ceks);
        assert!(config.enabled);
    }

    #[test]
    fn test_result_set_encryption_info() {
        let cek_table = CekTable::new();
        let mut info = ResultSetEncryptionInfo::new(cek_table, 3);

        assert!(!info.is_column_encrypted(0));
        assert!(!info.is_column_encrypted(1));
        assert!(!info.is_column_encrypted(2));

        info.set_column_crypto(1, sample_metadata());
        assert!(!info.is_column_encrypted(0));
        assert!(info.is_column_encrypted(1));
        assert!(!info.is_column_encrypted(2));

        assert_eq!(
            info.get_encryption_type(1),
            Some(EncryptionTypeWire::Deterministic)
        );
    }

    #[test]
    fn test_result_set_encryption_info_out_of_range() {
        let mut info = ResultSetEncryptionInfo::new(CekTable::new(), 2);

        // Out-of-range ordinals are ignored on write...
        info.set_column_crypto(5, sample_metadata());
        assert_eq!(info.column_crypto.len(), 2);
        assert!(info.column_crypto.iter().all(Option::is_none));

        // ...and report "not encrypted" on read.
        assert!(!info.is_column_encrypted(5));
        assert_eq!(info.get_encryption_type(5), None);
        assert!(info.get_cek_for_column(5).is_none());

        // An in-range but unencrypted column has no CEK either.
        assert!(info.get_cek_for_column(0).is_none());
    }

    #[test]
    fn test_parameter_encryption_info() {
        let mut info = ParameterEncryptionInfo::new();

        assert!(!info.needs_encryption("@p1"));

        let crypto = ParameterCryptoInfo::new(0, EncryptionTypeWire::Randomized, 2, 1, 1);
        info.add_parameter("@p1".to_string(), crypto);

        assert!(info.needs_encryption("@p1"));
        assert!(!info.needs_encryption("@p2"));

        let param = info.get_parameter("@p1").unwrap();
        assert_eq!(param.encryption_type, EncryptionTypeWire::Randomized);
    }

    #[test]
    fn test_parameter_encryption_info_default_is_empty() {
        let info = ParameterEncryptionInfo::default();
        assert!(info.parameters.is_empty());
        assert!(info.get_parameter("@p1").is_none());
    }
}
471}