1use std::collections::HashMap;
71
72use mssql_auth::KeyStoreProvider;
73use tds_protocol::crypto::{CekTable, CekTableEntry, CryptoMetadata, EncryptionTypeWire};
74
75#[cfg(feature = "always-encrypted")]
76use mssql_auth::{AeadEncryptor, CekCache, CekCacheKey, EncryptionError};
77#[cfg(feature = "always-encrypted")]
78use std::sync::Arc;
79
80#[derive(Default)]
82pub struct EncryptionConfig {
83 pub enabled: bool,
85 providers: Vec<Box<dyn KeyStoreProvider>>,
87 pub cache_ceks: bool,
89}
90
91impl EncryptionConfig {
92 #[must_use]
94 pub fn new() -> Self {
95 Self {
96 enabled: true,
97 providers: Vec::new(),
98 cache_ceks: true,
99 }
100 }
101
102 pub fn register_provider(&mut self, provider: impl KeyStoreProvider + 'static) {
104 self.providers.push(Box::new(provider));
105 }
106
107 #[must_use]
109 pub fn with_provider(mut self, provider: impl KeyStoreProvider + 'static) -> Self {
110 self.register_provider(provider);
111 self
112 }
113
114 #[must_use]
116 pub fn with_cek_caching(mut self, enabled: bool) -> Self {
117 self.cache_ceks = enabled;
118 self
119 }
120
121 pub fn get_provider(&self, name: &str) -> Option<&dyn KeyStoreProvider> {
123 self.providers
124 .iter()
125 .find(|p| p.provider_name() == name)
126 .map(|p| p.as_ref())
127 }
128
129 #[must_use]
131 pub fn is_ready(&self) -> bool {
132 self.enabled && !self.providers.is_empty()
133 }
134}
135
136impl std::fmt::Debug for EncryptionConfig {
137 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
138 f.debug_struct("EncryptionConfig")
139 .field("enabled", &self.enabled)
140 .field("provider_count", &self.providers.len())
141 .field("cache_ceks", &self.cache_ceks)
142 .finish()
143 }
144}
145
/// Always Encrypted runtime state.
///
/// Pairs a shared [`EncryptionConfig`], which supplies the key store
/// providers, with a cache of [`AeadEncryptor`]s built from decrypted column
/// encryption keys (CEKs).
#[cfg(feature = "always-encrypted")]
pub struct EncryptionContext {
    /// Shared configuration; key store providers are looked up here by name.
    config: Arc<EncryptionConfig>,
    /// Encryptor cache keyed by (database id, CEK id, CEK version).
    cek_cache: CekCache,
    /// Copy of `config.cache_ceks`, captured when the context is created.
    cache_enabled: bool,
}

#[cfg(feature = "always-encrypted")]
impl EncryptionContext {
    /// Creates a context around an already-shared configuration.
    ///
    /// The caching switch is copied from `config.cache_ceks` at construction.
    pub fn from_arc(config: Arc<EncryptionConfig>) -> Self {
        let cache_enabled = config.cache_ceks;
        Self {
            config,
            cek_cache: CekCache::new(),
            cache_enabled,
        }
    }

    /// Creates a context that takes ownership of `config`.
    pub fn new(config: EncryptionConfig) -> Self {
        Self::from_arc(Arc::new(config))
    }

    /// Returns an encryptor for the column encryption key described by `cek_entry`.
    ///
    /// When caching is enabled, a cached encryptor is returned if one exists.
    /// Otherwise the entry's primary encrypted CEK value is decrypted by the
    /// key store provider named in that value. The result is inserted into the
    /// cache when caching is enabled, or wrapped in a fresh encryptor otherwise.
    ///
    /// NOTE(review): no lock is held across the `.await` below, so concurrent
    /// misses for the same key each reach the key store.
    ///
    /// # Errors
    ///
    /// - [`EncryptionError::CekDecryptionFailed`] if the entry has no CEK value.
    /// - [`EncryptionError::KeyStoreNotFound`] if no registered provider matches
    ///   the value's `key_store_provider_name`.
    /// - Errors from the provider's `decrypt_cek`, from the cache insert, or from
    ///   [`AeadEncryptor::new`] are propagated.
    pub async fn get_encryptor(
        &self,
        cek_entry: &CekTableEntry,
    ) -> Result<Arc<AeadEncryptor>, EncryptionError> {
        let cache_key = CekCacheKey::new(
            cek_entry.database_id,
            cek_entry.cek_id,
            cek_entry.cek_version,
        );

        if self.cache_enabled {
            if let Some(encryptor) = self.cek_cache.get(&cache_key) {
                return Ok(encryptor);
            }
        }

        let cek_value = cek_entry
            .primary_value()
            .ok_or_else(|| EncryptionError::CekDecryptionFailed("No CEK value available".into()))?;

        let provider = self
            .config
            .get_provider(&cek_value.key_store_provider_name)
            .ok_or_else(|| {
                EncryptionError::KeyStoreNotFound(cek_value.key_store_provider_name.clone())
            })?;

        let decrypted_cek = provider
            .decrypt_cek(
                &cek_value.cmk_path,
                &cek_value.encryption_algorithm,
                &cek_value.encrypted_value,
            )
            .await?;

        if self.cache_enabled {
            self.cek_cache.insert(cache_key, decrypted_cek)
        } else {
            Ok(Arc::new(AeadEncryptor::new(&decrypted_cek)?))
        }
    }

    /// Encrypts `plaintext` with the CEK described by `cek_entry`.
    ///
    /// The requested encryption type is validated *before* the CEK is resolved,
    /// so an unsupported type is rejected without a key store round trip and
    /// without populating the cache.
    ///
    /// # Errors
    ///
    /// - [`EncryptionError::UnsupportedOperation`] if `encryption_type` is
    ///   neither `Deterministic` nor `Randomized`.
    /// - Any error from [`Self::get_encryptor`] or from the encryptor itself.
    pub async fn encrypt_value(
        &self,
        plaintext: &[u8],
        cek_entry: &CekTableEntry,
        encryption_type: EncryptionTypeWire,
    ) -> Result<Vec<u8>, EncryptionError> {
        let enc_type = match encryption_type {
            EncryptionTypeWire::Deterministic => mssql_auth::EncryptionType::Deterministic,
            EncryptionTypeWire::Randomized => mssql_auth::EncryptionType::Randomized,
            // Any other wire value has no client-side cipher mode.
            _ => {
                return Err(EncryptionError::UnsupportedOperation(format!(
                    "unsupported encryption type: {encryption_type:?}"
                )));
            }
        };

        let encryptor = self.get_encryptor(cek_entry).await?;
        encryptor.encrypt(plaintext, enc_type)
    }

    /// Decrypts `ciphertext` with the CEK described by `cek_entry`.
    ///
    /// # Errors
    ///
    /// Any error from [`Self::get_encryptor`] or from the encryptor's `decrypt`.
    pub async fn decrypt_value(
        &self,
        ciphertext: &[u8],
        cek_entry: &CekTableEntry,
    ) -> Result<Vec<u8>, EncryptionError> {
        let encryptor = self.get_encryptor(cek_entry).await?;
        encryptor.decrypt(ciphertext)
    }

    /// Clears the encryptor cache (delegates to `CekCache::clear`).
    pub fn clear_cache(&self) {
        self.cek_cache.clear();
    }

    /// Returns `true` if the configuration has a provider registered under `name`.
    pub fn has_provider(&self, name: &str) -> bool {
        self.config.get_provider(name).is_some()
    }
}
296
// Summarises the context: how many providers the config holds, how many
// encryptors are cached, and whether caching is on. No key material is printed.
#[cfg(feature = "always-encrypted")]
impl std::fmt::Debug for EncryptionContext {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let provider_count = self.config.providers.len();
        let cache_entries = self.cek_cache.len();
        let mut builder = f.debug_struct("EncryptionContext");
        builder.field("provider_count", &provider_count);
        builder.field("cache_entries", &cache_entries);
        builder.field("cache_enabled", &self.cache_enabled);
        builder.finish()
    }
}
307
308#[derive(Debug, Clone)]
313pub struct ResultSetEncryptionInfo {
314 pub cek_table: CekTable,
316 pub column_crypto: Vec<Option<CryptoMetadata>>,
318}
319
320impl ResultSetEncryptionInfo {
321 pub fn new(cek_table: CekTable, column_count: usize) -> Self {
323 Self {
324 cek_table,
325 column_crypto: vec![None; column_count],
326 }
327 }
328
329 pub fn set_column_crypto(&mut self, ordinal: usize, metadata: CryptoMetadata) {
331 if ordinal < self.column_crypto.len() {
332 self.column_crypto[ordinal] = Some(metadata);
333 }
334 }
335
336 pub fn get_cek_for_column(&self, ordinal: usize) -> Option<&CekTableEntry> {
338 let crypto = self.column_crypto.get(ordinal)?.as_ref()?;
339 self.cek_table.get(crypto.cek_table_ordinal)
340 }
341
342 pub fn is_column_encrypted(&self, ordinal: usize) -> bool {
344 self.column_crypto
345 .get(ordinal)
346 .map(|c| c.is_some())
347 .unwrap_or(false)
348 }
349
350 pub fn get_encryption_type(&self, ordinal: usize) -> Option<EncryptionTypeWire> {
352 self.column_crypto
353 .get(ordinal)?
354 .as_ref()
355 .map(|c| c.encryption_type)
356 }
357}
358
359#[derive(Debug, Clone)]
364pub struct ParameterEncryptionInfo {
365 pub cek_table: CekTable,
367 pub parameters: HashMap<String, ParameterCryptoInfo>,
369}
370
371impl ParameterEncryptionInfo {
372 pub fn new() -> Self {
374 Self {
375 cek_table: CekTable::new(),
376 parameters: HashMap::new(),
377 }
378 }
379
380 pub fn add_parameter(&mut self, name: String, info: ParameterCryptoInfo) {
382 self.parameters.insert(name, info);
383 }
384
385 pub fn get_parameter(&self, name: &str) -> Option<&ParameterCryptoInfo> {
387 self.parameters.get(name)
388 }
389
390 pub fn needs_encryption(&self, name: &str) -> bool {
392 self.parameters.contains_key(name)
393 }
394}
395
396impl Default for ParameterEncryptionInfo {
397 fn default() -> Self {
398 Self::new()
399 }
400}
401
402#[derive(Debug, Clone)]
404pub struct ParameterCryptoInfo {
405 pub cek_ordinal: u16,
407 pub encryption_type: EncryptionTypeWire,
409 pub algorithm_id: u8,
411 pub column_ordinal: u16,
413 pub database_id: u32,
415}
416
417impl ParameterCryptoInfo {
418 pub fn new(
420 cek_ordinal: u16,
421 encryption_type: EncryptionTypeWire,
422 algorithm_id: u8,
423 column_ordinal: u16,
424 database_id: u32,
425 ) -> Self {
426 Self {
427 cek_ordinal,
428 encryption_type,
429 algorithm_id,
430 column_ordinal,
431 database_id,
432 }
433 }
434}
435
#[cfg(test)]
// Panicking helpers are acceptable in tests: a panic is exactly a test failure.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Builds deterministic-encryption metadata that points at CEK table ordinal 0.
    fn sample_metadata() -> CryptoMetadata {
        CryptoMetadata {
            cek_table_ordinal: 0,
            base_user_type: 0,
            base_col_type: 0x26,
            base_type_info: tds_protocol::token::TypeInfo::default(),
            algorithm_id: 2,
            encryption_type: EncryptionTypeWire::Deterministic,
            normalization_version: 1,
        }
    }

    #[test]
    fn test_encryption_config_defaults() {
        let config = EncryptionConfig::new();
        assert!(config.enabled);
        assert!(config.cache_ceks);
        // No providers are registered yet, so the config is not usable.
        assert!(!config.is_ready());
    }

    #[test]
    fn test_encryption_config_default_trait_is_disabled() {
        // `#[derive(Default)]` differs from `new()`: everything is off.
        let config = EncryptionConfig::default();
        assert!(!config.enabled);
        assert!(!config.cache_ceks);
        assert!(!config.is_ready());
    }

    #[test]
    fn test_with_cek_caching_toggles_only_caching() {
        let config = EncryptionConfig::new().with_cek_caching(false);
        assert!(!config.cache_ceks);
        assert!(config.enabled);
    }

    #[test]
    fn test_get_provider_unknown_name() {
        let config = EncryptionConfig::new();
        assert!(config.get_provider("MSSQL_CERTIFICATE_STORE").is_none());
    }

    #[test]
    fn test_result_set_encryption_info() {
        let mut info = ResultSetEncryptionInfo::new(CekTable::new(), 3);

        assert!(!info.is_column_encrypted(0));
        assert!(!info.is_column_encrypted(1));
        assert!(!info.is_column_encrypted(2));

        info.set_column_crypto(1, sample_metadata());
        assert!(!info.is_column_encrypted(0));
        assert!(info.is_column_encrypted(1));
        assert!(!info.is_column_encrypted(2));

        assert_eq!(
            info.get_encryption_type(1),
            Some(EncryptionTypeWire::Deterministic)
        );
    }

    #[test]
    fn test_result_set_out_of_range_ordinals() {
        let mut info = ResultSetEncryptionInfo::new(CekTable::new(), 2);

        // Out-of-range writes are ignored rather than growing the vector.
        info.set_column_crypto(5, sample_metadata());
        assert_eq!(info.column_crypto.len(), 2);
        assert!(info.column_crypto.iter().all(Option::is_none));

        assert!(!info.is_column_encrypted(5));
        assert_eq!(info.get_encryption_type(5), None);
        assert_eq!(info.get_encryption_type(0), None);
        assert!(info.get_cek_for_column(5).is_none());
        assert!(info.get_cek_for_column(0).is_none());

        let empty = ResultSetEncryptionInfo::new(CekTable::new(), 0);
        assert!(!empty.is_column_encrypted(0));
    }

    #[test]
    fn test_parameter_encryption_info() {
        let mut info = ParameterEncryptionInfo::new();

        assert!(!info.needs_encryption("@p1"));

        let crypto = ParameterCryptoInfo::new(0, EncryptionTypeWire::Randomized, 2, 1, 1);
        info.add_parameter("@p1".to_string(), crypto);

        assert!(info.needs_encryption("@p1"));
        assert!(!info.needs_encryption("@p2"));

        let param = info.get_parameter("@p1").unwrap();
        assert_eq!(param.encryption_type, EncryptionTypeWire::Randomized);
    }

    #[test]
    fn test_parameter_encryption_info_default_is_empty() {
        let info = ParameterEncryptionInfo::default();
        assert!(info.parameters.is_empty());
        assert!(info.get_parameter("@p1").is_none());
    }

    #[test]
    fn test_parameter_encryption_info_replaces_duplicate_name() {
        let mut info = ParameterEncryptionInfo::new();
        info.add_parameter(
            "@p1".to_string(),
            ParameterCryptoInfo::new(0, EncryptionTypeWire::Randomized, 2, 1, 1),
        );
        info.add_parameter(
            "@p1".to_string(),
            ParameterCryptoInfo::new(1, EncryptionTypeWire::Deterministic, 2, 3, 1),
        );

        assert_eq!(info.parameters.len(), 1);
        let param = info.get_parameter("@p1").expect("@p1 was registered");
        assert_eq!(param.cek_ordinal, 1);
        assert_eq!(param.encryption_type, EncryptionTypeWire::Deterministic);
        assert_eq!(param.column_ordinal, 3);
    }
}