1use std::collections::HashMap;
48
49use mssql_auth::KeyStoreProvider;
50use tds_protocol::crypto::{CekTable, CekTableEntry, CryptoMetadata, EncryptionTypeWire};
51
52#[cfg(feature = "always-encrypted")]
53use mssql_auth::{AeadEncryptor, CekCache, CekCacheKey, EncryptionError};
54#[cfg(feature = "always-encrypted")]
55use std::sync::Arc;
56
57#[derive(Default)]
59pub struct EncryptionConfig {
60 pub enabled: bool,
62 providers: Vec<Box<dyn KeyStoreProvider>>,
64 pub cache_ceks: bool,
66}
67
68impl EncryptionConfig {
69 #[must_use]
71 pub fn new() -> Self {
72 Self {
73 enabled: true,
74 providers: Vec::new(),
75 cache_ceks: true,
76 }
77 }
78
79 pub fn register_provider(&mut self, provider: impl KeyStoreProvider + 'static) {
81 self.providers.push(Box::new(provider));
82 }
83
84 #[must_use]
86 pub fn with_provider(mut self, provider: impl KeyStoreProvider + 'static) -> Self {
87 self.register_provider(provider);
88 self
89 }
90
91 #[must_use]
93 pub fn with_cek_caching(mut self, enabled: bool) -> Self {
94 self.cache_ceks = enabled;
95 self
96 }
97
98 pub fn get_provider(&self, name: &str) -> Option<&dyn KeyStoreProvider> {
100 self.providers
101 .iter()
102 .find(|p| p.provider_name() == name)
103 .map(|p| p.as_ref())
104 }
105
106 #[must_use]
108 pub fn is_ready(&self) -> bool {
109 self.enabled && !self.providers.is_empty()
110 }
111}
112
113impl std::fmt::Debug for EncryptionConfig {
114 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115 f.debug_struct("EncryptionConfig")
116 .field("enabled", &self.enabled)
117 .field("provider_count", &self.providers.len())
118 .field("cache_ceks", &self.cache_ceks)
119 .finish()
120 }
121}
122
/// Runtime state for encrypting and decrypting column values: the key store
/// providers indexed by name, a cache of encryptors keyed by CEK identity,
/// and the flag that decides whether that cache is consulted.
///
/// Only compiled when the `always-encrypted` feature is enabled.
#[cfg(feature = "always-encrypted")]
pub struct EncryptionContext {
    /// Providers keyed by the name each one reports via `provider_name()`.
    providers: HashMap<String, Box<dyn KeyStoreProvider>>,
    /// Cache handing out shared encryptors for already-decrypted CEKs.
    cek_cache: CekCache,
    /// When `false`, the cache is neither read nor written.
    cache_enabled: bool,
}
136
#[cfg(feature = "always-encrypted")]
impl EncryptionContext {
    /// Builds a context from a reference-counted configuration.
    ///
    /// The providers are boxed trait objects that must be moved out of the
    /// configuration, which is only possible when `config` is the sole
    /// strong reference (`Arc::try_unwrap` succeeds). Otherwise a warning is
    /// logged and the returned context has **no providers**, so it cannot
    /// resolve a key store for any CEK. The caching preference of the shared
    /// configuration is still honored in that case.
    pub fn from_arc(config: std::sync::Arc<EncryptionConfig>) -> Self {
        match std::sync::Arc::try_unwrap(config) {
            Ok(owned) => Self::new(owned),
            Err(shared) => {
                tracing::warn!(
                    "EncryptionConfig has multiple references; \
                     creating EncryptionContext without providers"
                );
                Self {
                    providers: std::collections::HashMap::new(),
                    cek_cache: CekCache::new(),
                    // The providers cannot be taken from a shared config,
                    // but its caching flag can still be read.
                    cache_enabled: shared.cache_ceks,
                }
            }
        }
    }

    /// Consumes `config` and builds a context with an empty CEK cache.
    ///
    /// Providers are indexed by the name they report through
    /// `provider_name()`. If several providers report the same name, only
    /// one of them is retained in the map.
    pub fn new(config: EncryptionConfig) -> Self {
        let providers = config
            .providers
            .into_iter()
            .map(|p| (p.provider_name().to_string(), p))
            .collect();

        Self {
            providers,
            cek_cache: CekCache::new(),
            cache_enabled: config.cache_ceks,
        }
    }

    /// Returns an encryptor for the column encryption key described by
    /// `cek_entry`.
    ///
    /// With caching enabled, a cached encryptor for the entry's
    /// `(database_id, cek_id, cek_version)` key is returned if present.
    /// Otherwise the entry's primary encrypted value is decrypted by the
    /// provider named in that value, and the result is either stored in the
    /// cache (caching enabled) or wrapped in a fresh encryptor.
    ///
    /// # Errors
    ///
    /// * `EncryptionError::CekDecryptionFailed` when the entry has no
    ///   primary value.
    /// * `EncryptionError::KeyStoreNotFound` when no provider is registered
    ///   under the value's provider name.
    /// * Any error returned by the provider's `decrypt_cek`, by the cache
    ///   insertion, or by `AeadEncryptor::new`.
    pub async fn get_encryptor(
        &self,
        cek_entry: &CekTableEntry,
    ) -> Result<Arc<AeadEncryptor>, EncryptionError> {
        let cache_key = CekCacheKey::new(
            cek_entry.database_id,
            cek_entry.cek_id,
            cek_entry.cek_version,
        );

        // Fast path: reuse an encryptor that was built earlier.
        if self.cache_enabled {
            if let Some(encryptor) = self.cek_cache.get(&cache_key) {
                return Ok(encryptor);
            }
        }

        let cek_value = cek_entry
            .primary_value()
            .ok_or_else(|| EncryptionError::CekDecryptionFailed("No CEK value available".into()))?;

        let provider = self
            .providers
            .get(&cek_value.key_store_provider_name)
            .ok_or_else(|| {
                EncryptionError::KeyStoreNotFound(cek_value.key_store_provider_name.clone())
            })?;

        let decrypted_cek = provider
            .decrypt_cek(
                &cek_value.cmk_path,
                &cek_value.encryption_algorithm,
                &cek_value.encrypted_value,
            )
            .await?;

        if self.cache_enabled {
            // The cache builds the encryptor and hands back the shared handle.
            self.cek_cache.insert(cache_key, decrypted_cek)
        } else {
            Ok(Arc::new(AeadEncryptor::new(&decrypted_cek)?))
        }
    }

    /// Encrypts `plaintext` with the key described by `cek_entry`.
    ///
    /// `encryption_type` is translated from its wire representation to the
    /// `mssql_auth::EncryptionType` expected by the encryptor.
    ///
    /// # Errors
    ///
    /// Returns `EncryptionError::UnsupportedOperation` for any wire value
    /// other than `Deterministic` or `Randomized`, and propagates errors
    /// from [`get_encryptor`](Self::get_encryptor) and from the encryptor.
    pub async fn encrypt_value(
        &self,
        plaintext: &[u8],
        cek_entry: &CekTableEntry,
        encryption_type: EncryptionTypeWire,
    ) -> Result<Vec<u8>, EncryptionError> {
        let encryptor = self.get_encryptor(cek_entry).await?;

        let enc_type = match encryption_type {
            EncryptionTypeWire::Deterministic => mssql_auth::EncryptionType::Deterministic,
            EncryptionTypeWire::Randomized => mssql_auth::EncryptionType::Randomized,
            _ => {
                return Err(EncryptionError::UnsupportedOperation(format!(
                    "unsupported encryption type: {encryption_type:?}"
                )));
            }
        };

        encryptor.encrypt(plaintext, enc_type)
    }

    /// Decrypts `ciphertext` with the key described by `cek_entry`.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`get_encryptor`](Self::get_encryptor) and
    /// from the encryptor's `decrypt`.
    pub async fn decrypt_value(
        &self,
        ciphertext: &[u8],
        cek_entry: &CekTableEntry,
    ) -> Result<Vec<u8>, EncryptionError> {
        let encryptor = self.get_encryptor(cek_entry).await?;
        encryptor.decrypt(ciphertext)
    }

    /// Empties the CEK cache.
    pub fn clear_cache(&self) {
        self.cek_cache.clear();
    }

    /// Reports whether a provider is registered under `name`.
    pub fn has_provider(&self, name: &str) -> bool {
        self.providers.contains_key(name)
    }
}
288
/// Manual `Debug` implementation: prints the registered provider names, the
/// number of cache entries and the caching flag rather than the provider
/// objects or cache contents themselves.
#[cfg(feature = "always-encrypted")]
impl std::fmt::Debug for EncryptionContext {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let provider_names = self.providers.keys().collect::<Vec<_>>();
        let cache_entries = self.cek_cache.len();

        let mut out = f.debug_struct("EncryptionContext");
        out.field("providers", &provider_names);
        out.field("cache_entries", &cache_entries);
        out.field("cache_enabled", &self.cache_enabled);
        out.finish()
    }
}
299
300#[derive(Debug, Clone)]
305pub struct ResultSetEncryptionInfo {
306 pub cek_table: CekTable,
308 pub column_crypto: Vec<Option<CryptoMetadata>>,
310}
311
312impl ResultSetEncryptionInfo {
313 pub fn new(cek_table: CekTable, column_count: usize) -> Self {
315 Self {
316 cek_table,
317 column_crypto: vec![None; column_count],
318 }
319 }
320
321 pub fn set_column_crypto(&mut self, ordinal: usize, metadata: CryptoMetadata) {
323 if ordinal < self.column_crypto.len() {
324 self.column_crypto[ordinal] = Some(metadata);
325 }
326 }
327
328 pub fn get_cek_for_column(&self, ordinal: usize) -> Option<&CekTableEntry> {
330 let crypto = self.column_crypto.get(ordinal)?.as_ref()?;
331 self.cek_table.get(crypto.cek_table_ordinal)
332 }
333
334 pub fn is_column_encrypted(&self, ordinal: usize) -> bool {
336 self.column_crypto
337 .get(ordinal)
338 .map(|c| c.is_some())
339 .unwrap_or(false)
340 }
341
342 pub fn get_encryption_type(&self, ordinal: usize) -> Option<EncryptionTypeWire> {
344 self.column_crypto
345 .get(ordinal)?
346 .as_ref()
347 .map(|c| c.encryption_type)
348 }
349}
350
351#[derive(Debug, Clone)]
356pub struct ParameterEncryptionInfo {
357 pub cek_table: CekTable,
359 pub parameters: HashMap<String, ParameterCryptoInfo>,
361}
362
363impl ParameterEncryptionInfo {
364 pub fn new() -> Self {
366 Self {
367 cek_table: CekTable::new(),
368 parameters: HashMap::new(),
369 }
370 }
371
372 pub fn add_parameter(&mut self, name: String, info: ParameterCryptoInfo) {
374 self.parameters.insert(name, info);
375 }
376
377 pub fn get_parameter(&self, name: &str) -> Option<&ParameterCryptoInfo> {
379 self.parameters.get(name)
380 }
381
382 pub fn needs_encryption(&self, name: &str) -> bool {
384 self.parameters.contains_key(name)
385 }
386}
387
388impl Default for ParameterEncryptionInfo {
389 fn default() -> Self {
390 Self::new()
391 }
392}
393
394#[derive(Debug, Clone)]
396pub struct ParameterCryptoInfo {
397 pub cek_ordinal: u16,
399 pub encryption_type: EncryptionTypeWire,
401 pub algorithm_id: u8,
403 pub column_ordinal: u16,
405 pub database_id: u32,
407}
408
409impl ParameterCryptoInfo {
410 pub fn new(
412 cek_ordinal: u16,
413 encryption_type: EncryptionTypeWire,
414 algorithm_id: u8,
415 column_ordinal: u16,
416 database_id: u32,
417 ) -> Self {
418 Self {
419 cek_ordinal,
420 encryption_type,
421 algorithm_id,
422 column_ordinal,
423 database_id,
424 }
425 }
426}
427
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// `new()` enables encryption and caching, but without a provider the
    /// configuration is not ready.
    #[test]
    fn test_encryption_config_defaults() {
        let config = EncryptionConfig::new();

        assert!(config.enabled);
        assert!(config.cache_ceks);
        assert!(!config.is_ready());
    }

    /// Columns start out unencrypted; setting metadata on one column marks
    /// only that column and exposes its encryption type.
    #[test]
    fn test_result_set_encryption_info() {
        let mut info = ResultSetEncryptionInfo::new(CekTable::new(), 3);

        for ordinal in 0..3 {
            assert!(!info.is_column_encrypted(ordinal));
        }

        let metadata = CryptoMetadata {
            cek_table_ordinal: 0,
            base_user_type: 0,
            base_col_type: 0x26,
            base_type_info: tds_protocol::token::TypeInfo::default(),
            algorithm_id: 2,
            encryption_type: EncryptionTypeWire::Deterministic,
            normalization_version: 1,
        };
        info.set_column_crypto(1, metadata);

        let flags: Vec<bool> = (0..3).map(|i| info.is_column_encrypted(i)).collect();
        assert_eq!(flags, [false, true, false]);

        assert_eq!(
            info.get_encryption_type(1),
            Some(EncryptionTypeWire::Deterministic)
        );
    }

    /// Only parameters that were added report `needs_encryption`, and the
    /// stored crypto info can be read back.
    #[test]
    fn test_parameter_encryption_info() {
        let mut info = ParameterEncryptionInfo::new();
        assert!(!info.needs_encryption("@p1"));

        info.add_parameter(
            "@p1".to_string(),
            ParameterCryptoInfo::new(0, EncryptionTypeWire::Randomized, 2, 1, 1),
        );

        assert!(info.needs_encryption("@p1"));
        assert!(!info.needs_encryption("@p2"));

        let param = info.get_parameter("@p1").unwrap();
        assert_eq!(param.encryption_type, EncryptionTypeWire::Randomized);
    }
}