1use std::collections::HashMap;
48
49use mssql_auth::KeyStoreProvider;
50use tds_protocol::crypto::{CekTable, CekTableEntry, CryptoMetadata, EncryptionTypeWire};
51
52#[cfg(feature = "always-encrypted")]
53use mssql_auth::{AeadEncryptor, CekCache, CekCacheKey, EncryptionError};
54#[cfg(feature = "always-encrypted")]
55use std::sync::Arc;
56
57#[derive(Default)]
59pub struct EncryptionConfig {
60 pub enabled: bool,
62 providers: Vec<Box<dyn KeyStoreProvider>>,
64 pub cache_ceks: bool,
66}
67
68impl EncryptionConfig {
69 #[must_use]
71 pub fn new() -> Self {
72 Self {
73 enabled: true,
74 providers: Vec::new(),
75 cache_ceks: true,
76 }
77 }
78
79 pub fn register_provider(&mut self, provider: impl KeyStoreProvider + 'static) {
81 self.providers.push(Box::new(provider));
82 }
83
84 #[must_use]
86 pub fn with_provider(mut self, provider: impl KeyStoreProvider + 'static) -> Self {
87 self.register_provider(provider);
88 self
89 }
90
91 #[must_use]
93 pub fn with_cek_caching(mut self, enabled: bool) -> Self {
94 self.cache_ceks = enabled;
95 self
96 }
97
98 pub fn get_provider(&self, name: &str) -> Option<&dyn KeyStoreProvider> {
100 self.providers
101 .iter()
102 .find(|p| p.provider_name() == name)
103 .map(|p| p.as_ref())
104 }
105
106 #[must_use]
108 pub fn is_ready(&self) -> bool {
109 self.enabled && !self.providers.is_empty()
110 }
111}
112
113impl std::fmt::Debug for EncryptionConfig {
114 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115 f.debug_struct("EncryptionConfig")
116 .field("enabled", &self.enabled)
117 .field("provider_count", &self.providers.len())
118 .field("cache_ceks", &self.cache_ceks)
119 .finish()
120 }
121}
122
#[cfg(feature = "always-encrypted")]
/// Runtime state used to encrypt and decrypt Always Encrypted values.
///
/// Built from an [`EncryptionConfig`]; owns the key store providers (indexed by
/// name) and a cache of encryptors derived from decrypted CEKs.
pub struct EncryptionContext {
    /// Key store providers keyed by the name each one reports.
    providers: HashMap<String, Box<dyn KeyStoreProvider>>,
    /// Encryptors for already-decrypted CEKs, keyed by `CekCacheKey`.
    cek_cache: CekCache,
    /// Copied from `EncryptionConfig::cache_ceks`; gates reads from and writes to `cek_cache`.
    cache_enabled: bool,
}
136
#[cfg(feature = "always-encrypted")]
impl EncryptionContext {
    /// Builds a context from `config`, taking ownership of its providers.
    ///
    /// Providers are indexed by the name they report. If two share a name, the
    /// later one overwrites the earlier map entry. `config.enabled` is not
    /// consulted here.
    pub fn new(config: EncryptionConfig) -> Self {
        let cache_enabled = config.cache_ceks;

        let mut providers = HashMap::new();
        for provider in config.providers {
            providers.insert(provider.provider_name().to_string(), provider);
        }

        Self {
            providers,
            cek_cache: CekCache::new(),
            cache_enabled,
        }
    }

    /// Returns an encryptor for the column encryption key described by `cek_entry`.
    ///
    /// With caching enabled, a cached encryptor for
    /// `(database_id, cek_id, cek_version)` is returned when present. Otherwise the
    /// entry's primary encrypted value is decrypted by the provider it names. With
    /// caching enabled the result is also stored in the cache.
    ///
    /// # Errors
    ///
    /// - `EncryptionError::CekDecryptionFailed` if the entry has no primary value.
    /// - `EncryptionError::KeyStoreNotFound` if the named provider is not registered.
    /// - Any error from the provider's `decrypt_cek`, from the cache insert, or from
    ///   `AeadEncryptor::new`.
    pub async fn get_encryptor(
        &self,
        cek_entry: &CekTableEntry,
    ) -> Result<Arc<AeadEncryptor>, EncryptionError> {
        let key = CekCacheKey::new(
            cek_entry.database_id,
            cek_entry.cek_id,
            cek_entry.cek_version,
        );

        // Fast path: reuse a previously built encryptor when caching is on.
        if let Some(cached) = self
            .cache_enabled
            .then(|| self.cek_cache.get(&key))
            .flatten()
        {
            return Ok(cached);
        }

        let Some(value) = cek_entry.primary_value() else {
            return Err(EncryptionError::CekDecryptionFailed(
                "No CEK value available".into(),
            ));
        };

        let provider_name = &value.key_store_provider_name;
        let Some(provider) = self.providers.get(provider_name) else {
            return Err(EncryptionError::KeyStoreNotFound(provider_name.clone()));
        };

        let plaintext_cek = provider
            .decrypt_cek(
                &value.cmk_path,
                &value.encryption_algorithm,
                &value.encrypted_value,
            )
            .await?;

        if !self.cache_enabled {
            return Ok(Arc::new(AeadEncryptor::new(&plaintext_cek)?));
        }
        self.cek_cache.insert(key, plaintext_cek)
    }

    /// Encrypts `plaintext` with the CEK from `cek_entry` using the given wire encryption type.
    ///
    /// # Errors
    ///
    /// Propagates failures from [`get_encryptor`](Self::get_encryptor) and from the encryptor.
    pub async fn encrypt_value(
        &self,
        plaintext: &[u8],
        cek_entry: &CekTableEntry,
        encryption_type: EncryptionTypeWire,
    ) -> Result<Vec<u8>, EncryptionError> {
        // Translate the protocol-level enum into the crypto crate's enum.
        let mode = match encryption_type {
            EncryptionTypeWire::Randomized => mssql_auth::EncryptionType::Randomized,
            EncryptionTypeWire::Deterministic => mssql_auth::EncryptionType::Deterministic,
        };
        self.get_encryptor(cek_entry).await?.encrypt(plaintext, mode)
    }

    /// Decrypts `ciphertext` with the CEK from `cek_entry`.
    ///
    /// # Errors
    ///
    /// Propagates failures from [`get_encryptor`](Self::get_encryptor) and from the encryptor.
    pub async fn decrypt_value(
        &self,
        ciphertext: &[u8],
        cek_entry: &CekTableEntry,
    ) -> Result<Vec<u8>, EncryptionError> {
        self.get_encryptor(cek_entry).await?.decrypt(ciphertext)
    }

    /// Drops every cached encryptor.
    pub fn clear_cache(&self) {
        self.cek_cache.clear();
    }

    /// Returns `true` if a provider with exactly this name is registered.
    pub fn has_provider(&self, name: &str) -> bool {
        self.providers.contains_key(name)
    }
}
258
#[cfg(feature = "always-encrypted")]
impl std::fmt::Debug for EncryptionContext {
    // Shows provider names and the cache size, never key material.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let provider_names: Vec<&String> = self.providers.keys().collect();
        let cache_entries = self.cek_cache.len();
        f.debug_struct("EncryptionContext")
            .field("providers", &provider_names)
            .field("cache_entries", &cache_entries)
            .field("cache_enabled", &self.cache_enabled)
            .finish()
    }
}
269
270#[derive(Debug, Clone)]
275pub struct ResultSetEncryptionInfo {
276 pub cek_table: CekTable,
278 pub column_crypto: Vec<Option<CryptoMetadata>>,
280}
281
282impl ResultSetEncryptionInfo {
283 pub fn new(cek_table: CekTable, column_count: usize) -> Self {
285 Self {
286 cek_table,
287 column_crypto: vec![None; column_count],
288 }
289 }
290
291 pub fn set_column_crypto(&mut self, ordinal: usize, metadata: CryptoMetadata) {
293 if ordinal < self.column_crypto.len() {
294 self.column_crypto[ordinal] = Some(metadata);
295 }
296 }
297
298 pub fn get_cek_for_column(&self, ordinal: usize) -> Option<&CekTableEntry> {
300 let crypto = self.column_crypto.get(ordinal)?.as_ref()?;
301 self.cek_table.get(crypto.cek_table_ordinal)
302 }
303
304 pub fn is_column_encrypted(&self, ordinal: usize) -> bool {
306 self.column_crypto
307 .get(ordinal)
308 .map(|c| c.is_some())
309 .unwrap_or(false)
310 }
311
312 pub fn get_encryption_type(&self, ordinal: usize) -> Option<EncryptionTypeWire> {
314 self.column_crypto
315 .get(ordinal)?
316 .as_ref()
317 .map(|c| c.encryption_type)
318 }
319}
320
321#[derive(Debug, Clone)]
326pub struct ParameterEncryptionInfo {
327 pub cek_table: CekTable,
329 pub parameters: HashMap<String, ParameterCryptoInfo>,
331}
332
333impl ParameterEncryptionInfo {
334 pub fn new() -> Self {
336 Self {
337 cek_table: CekTable::new(),
338 parameters: HashMap::new(),
339 }
340 }
341
342 pub fn add_parameter(&mut self, name: String, info: ParameterCryptoInfo) {
344 self.parameters.insert(name, info);
345 }
346
347 pub fn get_parameter(&self, name: &str) -> Option<&ParameterCryptoInfo> {
349 self.parameters.get(name)
350 }
351
352 pub fn needs_encryption(&self, name: &str) -> bool {
354 self.parameters.contains_key(name)
355 }
356}
357
358impl Default for ParameterEncryptionInfo {
359 fn default() -> Self {
360 Self::new()
361 }
362}
363
364#[derive(Debug, Clone)]
366pub struct ParameterCryptoInfo {
367 pub cek_ordinal: u16,
369 pub encryption_type: EncryptionTypeWire,
371 pub algorithm_id: u8,
373 pub column_ordinal: u16,
375 pub database_id: u32,
377}
378
379impl ParameterCryptoInfo {
380 pub fn new(
382 cek_ordinal: u16,
383 encryption_type: EncryptionTypeWire,
384 algorithm_id: u8,
385 column_ordinal: u16,
386 database_id: u32,
387 ) -> Self {
388 Self {
389 cek_ordinal,
390 encryption_type,
391 algorithm_id,
392 column_ordinal,
393 database_id,
394 }
395 }
396}
397
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    #[test]
    fn test_encryption_config_defaults() {
        let config = EncryptionConfig::new();
        assert!(config.enabled);
        assert!(config.cache_ceks);
        // No providers have been registered, so the config is not ready.
        assert!(!config.is_ready());
    }

    #[test]
    fn test_result_set_encryption_info() {
        let mut info = ResultSetEncryptionInfo::new(CekTable::new(), 3);

        // Every column starts out unencrypted.
        for column in 0..3 {
            assert!(!info.is_column_encrypted(column));
        }

        info.set_column_crypto(
            1,
            CryptoMetadata {
                cek_table_ordinal: 0,
                algorithm_id: 2,
                encryption_type: EncryptionTypeWire::Deterministic,
                normalization_version: 1,
            },
        );

        // Only the column we configured is now encrypted.
        let flags: Vec<bool> = (0..3).map(|c| info.is_column_encrypted(c)).collect();
        assert_eq!(flags, [false, true, false]);

        assert_eq!(
            info.get_encryption_type(1),
            Some(EncryptionTypeWire::Deterministic)
        );
    }

    #[test]
    fn test_parameter_encryption_info() {
        let mut info = ParameterEncryptionInfo::new();
        assert!(!info.needs_encryption("@p1"));

        info.add_parameter(
            "@p1".to_string(),
            ParameterCryptoInfo::new(0, EncryptionTypeWire::Randomized, 2, 1, 1),
        );

        assert!(info.needs_encryption("@p1"));
        assert!(!info.needs_encryption("@p2"));

        let param = info.get_parameter("@p1").expect("@p1 was just added");
        assert_eq!(param.encryption_type, EncryptionTypeWire::Randomized);
    }
}