1use std::collections::HashMap;
48
49use mssql_auth::KeyStoreProvider;
50use tds_protocol::crypto::{CekTable, CekTableEntry, CryptoMetadata, EncryptionTypeWire};
51
52#[cfg(feature = "always-encrypted")]
53use mssql_auth::{AeadEncryptor, CekCache, CekCacheKey, EncryptionError};
54#[cfg(feature = "always-encrypted")]
55use std::sync::Arc;
56
57#[derive(Default)]
59pub struct EncryptionConfig {
60 pub enabled: bool,
62 providers: Vec<Box<dyn KeyStoreProvider>>,
64 pub cache_ceks: bool,
66}
67
68impl EncryptionConfig {
69 #[must_use]
71 pub fn new() -> Self {
72 Self {
73 enabled: true,
74 providers: Vec::new(),
75 cache_ceks: true,
76 }
77 }
78
79 pub fn register_provider(&mut self, provider: impl KeyStoreProvider + 'static) {
81 self.providers.push(Box::new(provider));
82 }
83
84 #[must_use]
86 pub fn with_provider(mut self, provider: impl KeyStoreProvider + 'static) -> Self {
87 self.register_provider(provider);
88 self
89 }
90
91 #[must_use]
93 pub fn with_cek_caching(mut self, enabled: bool) -> Self {
94 self.cache_ceks = enabled;
95 self
96 }
97
98 pub fn get_provider(&self, name: &str) -> Option<&dyn KeyStoreProvider> {
100 self.providers
101 .iter()
102 .find(|p| p.provider_name() == name)
103 .map(|p| p.as_ref())
104 }
105
106 #[must_use]
108 pub fn is_ready(&self) -> bool {
109 self.enabled && !self.providers.is_empty()
110 }
111}
112
113impl std::fmt::Debug for EncryptionConfig {
114 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115 f.debug_struct("EncryptionConfig")
116 .field("enabled", &self.enabled)
117 .field("provider_count", &self.providers.len())
118 .field("cache_ceks", &self.cache_ceks)
119 .finish()
120 }
121}
122
#[cfg(feature = "always-encrypted")]
/// Runtime Always Encrypted state: key store providers looked up by name, plus
/// a cache of encryptors built from decrypted column encryption keys.
pub struct EncryptionContext {
    /// Key store providers, keyed by their provider name.
    providers: HashMap<String, Box<dyn KeyStoreProvider>>,
    /// Cache of encryptors, keyed by (database id, CEK id, CEK version).
    cek_cache: CekCache,
    /// Whether `cek_cache` is consulted and populated.
    cache_enabled: bool,
}
136
#[cfg(feature = "always-encrypted")]
impl EncryptionContext {
    /// Builds a context from `config`, taking ownership of its providers.
    ///
    /// Providers are keyed by [`KeyStoreProvider::provider_name`]; if several
    /// share a name, only one of them is kept. `config.enabled` is not consulted
    /// here; only `config.cache_ceks` is carried over.
    pub fn new(config: EncryptionConfig) -> Self {
        let providers = config
            .providers
            .into_iter()
            .map(|p| (p.provider_name().to_string(), p))
            .collect();

        Self {
            providers,
            cek_cache: CekCache::new(),
            cache_enabled: config.cache_ceks,
        }
    }

    /// Returns an encryptor for `cek_entry`.
    ///
    /// When caching is enabled, a cached encryptor is returned if present.
    /// Otherwise the entry's primary CEK value is decrypted through the
    /// registered key store provider it names. When caching is enabled, the
    /// resulting encryptor is also stored in the cache.
    ///
    /// # Errors
    ///
    /// - `CekDecryptionFailed` if the entry has no primary CEK value.
    /// - `KeyStoreNotFound` if no provider with the required name is registered.
    /// - Any error returned by the provider's `decrypt_cek` or by encryptor
    ///   construction.
    pub async fn get_encryptor(
        &self,
        cek_entry: &CekTableEntry,
    ) -> Result<Arc<AeadEncryptor>, EncryptionError> {
        let cache_key = CekCacheKey::new(
            cek_entry.database_id,
            cek_entry.cek_id,
            cek_entry.cek_version,
        );

        if self.cache_enabled {
            if let Some(encryptor) = self.cek_cache.get(&cache_key) {
                return Ok(encryptor);
            }
        }

        let cek_value = cek_entry
            .primary_value()
            .ok_or_else(|| EncryptionError::CekDecryptionFailed("No CEK value available".into()))?;

        let provider = self
            .providers
            .get(&cek_value.key_store_provider_name)
            .ok_or_else(|| {
                EncryptionError::KeyStoreNotFound(cek_value.key_store_provider_name.clone())
            })?;

        let decrypted_cek = provider
            .decrypt_cek(
                &cek_value.cmk_path,
                &cek_value.encryption_algorithm,
                &cek_value.encrypted_value,
            )
            .await?;

        if self.cache_enabled {
            self.cek_cache.insert(cache_key, decrypted_cek)
        } else {
            Ok(Arc::new(AeadEncryptor::new(&decrypted_cek)?))
        }
    }

    /// Encrypts `plaintext` with the CEK described by `cek_entry`.
    ///
    /// Only deterministic and randomized encryption are supported. The
    /// requested type is validated *before* the CEK is resolved. An unsupported
    /// type therefore fails without calling the key store provider and without
    /// touching the CEK cache.
    ///
    /// # Errors
    ///
    /// `UnsupportedOperation` for any other encryption type, plus any error from
    /// [`get_encryptor`](Self::get_encryptor) or from the encryptor itself.
    pub async fn encrypt_value(
        &self,
        plaintext: &[u8],
        cek_entry: &CekTableEntry,
        encryption_type: EncryptionTypeWire,
    ) -> Result<Vec<u8>, EncryptionError> {
        // Fail fast: resolving the encryptor may require an async key store call.
        let enc_type = match encryption_type {
            EncryptionTypeWire::Deterministic => mssql_auth::EncryptionType::Deterministic,
            EncryptionTypeWire::Randomized => mssql_auth::EncryptionType::Randomized,
            _ => {
                return Err(EncryptionError::UnsupportedOperation(format!(
                    "unsupported encryption type: {encryption_type:?}"
                )));
            }
        };

        let encryptor = self.get_encryptor(cek_entry).await?;
        encryptor.encrypt(plaintext, enc_type)
    }

    /// Decrypts `ciphertext` with the CEK described by `cek_entry`.
    ///
    /// # Errors
    ///
    /// Any error from [`get_encryptor`](Self::get_encryptor) or from the
    /// encryptor's `decrypt`.
    pub async fn decrypt_value(
        &self,
        ciphertext: &[u8],
        cek_entry: &CekTableEntry,
    ) -> Result<Vec<u8>, EncryptionError> {
        let encryptor = self.get_encryptor(cek_entry).await?;
        encryptor.decrypt(ciphertext)
    }

    /// Removes all cached encryptors. Whether caching is enabled is unchanged.
    pub fn clear_cache(&self) {
        self.cek_cache.clear();
    }

    /// Returns `true` if a provider with the given name is registered.
    pub fn has_provider(&self, name: &str) -> bool {
        self.providers.contains_key(name)
    }
}
263
#[cfg(feature = "always-encrypted")]
impl std::fmt::Debug for EncryptionContext {
    /// Formats the context. Shows provider names (not the providers), the
    /// number of cached entries, and whether caching is enabled.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `HashMap` iteration order is unspecified; sort so output is stable
        // across runs (useful for logs and snapshot-style tests).
        let mut provider_names: Vec<&String> = self.providers.keys().collect();
        provider_names.sort_unstable();

        f.debug_struct("EncryptionContext")
            .field("providers", &provider_names)
            .field("cache_entries", &self.cek_cache.len())
            .field("cache_enabled", &self.cache_enabled)
            .finish()
    }
}
274
275#[derive(Debug, Clone)]
280pub struct ResultSetEncryptionInfo {
281 pub cek_table: CekTable,
283 pub column_crypto: Vec<Option<CryptoMetadata>>,
285}
286
287impl ResultSetEncryptionInfo {
288 pub fn new(cek_table: CekTable, column_count: usize) -> Self {
290 Self {
291 cek_table,
292 column_crypto: vec![None; column_count],
293 }
294 }
295
296 pub fn set_column_crypto(&mut self, ordinal: usize, metadata: CryptoMetadata) {
298 if ordinal < self.column_crypto.len() {
299 self.column_crypto[ordinal] = Some(metadata);
300 }
301 }
302
303 pub fn get_cek_for_column(&self, ordinal: usize) -> Option<&CekTableEntry> {
305 let crypto = self.column_crypto.get(ordinal)?.as_ref()?;
306 self.cek_table.get(crypto.cek_table_ordinal)
307 }
308
309 pub fn is_column_encrypted(&self, ordinal: usize) -> bool {
311 self.column_crypto
312 .get(ordinal)
313 .map(|c| c.is_some())
314 .unwrap_or(false)
315 }
316
317 pub fn get_encryption_type(&self, ordinal: usize) -> Option<EncryptionTypeWire> {
319 self.column_crypto
320 .get(ordinal)?
321 .as_ref()
322 .map(|c| c.encryption_type)
323 }
324}
325
326#[derive(Debug, Clone)]
331pub struct ParameterEncryptionInfo {
332 pub cek_table: CekTable,
334 pub parameters: HashMap<String, ParameterCryptoInfo>,
336}
337
338impl ParameterEncryptionInfo {
339 pub fn new() -> Self {
341 Self {
342 cek_table: CekTable::new(),
343 parameters: HashMap::new(),
344 }
345 }
346
347 pub fn add_parameter(&mut self, name: String, info: ParameterCryptoInfo) {
349 self.parameters.insert(name, info);
350 }
351
352 pub fn get_parameter(&self, name: &str) -> Option<&ParameterCryptoInfo> {
354 self.parameters.get(name)
355 }
356
357 pub fn needs_encryption(&self, name: &str) -> bool {
359 self.parameters.contains_key(name)
360 }
361}
362
363impl Default for ParameterEncryptionInfo {
364 fn default() -> Self {
365 Self::new()
366 }
367}
368
369#[derive(Debug, Clone)]
371pub struct ParameterCryptoInfo {
372 pub cek_ordinal: u16,
374 pub encryption_type: EncryptionTypeWire,
376 pub algorithm_id: u8,
378 pub column_ordinal: u16,
380 pub database_id: u32,
382}
383
384impl ParameterCryptoInfo {
385 pub fn new(
387 cek_ordinal: u16,
388 encryption_type: EncryptionTypeWire,
389 algorithm_id: u8,
390 column_ordinal: u16,
391 database_id: u32,
392 ) -> Self {
393 Self {
394 cek_ordinal,
395 encryption_type,
396 algorithm_id,
397 column_ordinal,
398 database_id,
399 }
400 }
401}
402
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Column metadata shared by the result-set tests.
    fn deterministic_metadata() -> CryptoMetadata {
        CryptoMetadata {
            cek_table_ordinal: 0,
            algorithm_id: 2,
            encryption_type: EncryptionTypeWire::Deterministic,
            normalization_version: 1,
        }
    }

    #[test]
    fn test_encryption_config_defaults() {
        let config = EncryptionConfig::new();
        assert!(config.enabled);
        assert!(config.cache_ceks);
        // No providers registered yet.
        assert!(!config.is_ready());
    }

    #[test]
    fn test_encryption_config_derived_default_is_disabled() {
        // `Default` is derived, so unlike `new()` everything starts switched off.
        let config = EncryptionConfig::default();
        assert!(!config.enabled);
        assert!(!config.cache_ceks);
        assert!(!config.is_ready());
        assert!(config.get_provider("anything").is_none());
    }

    #[test]
    fn test_encryption_config_cek_caching_toggle() {
        let config = EncryptionConfig::new().with_cek_caching(false);
        assert!(config.enabled);
        assert!(!config.cache_ceks);
        assert!(config.with_cek_caching(true).cache_ceks);
    }

    #[test]
    fn test_encryption_config_debug_output() {
        let rendered = format!("{:?}", EncryptionConfig::new());
        assert_eq!(
            rendered,
            "EncryptionConfig { enabled: true, provider_count: 0, cache_ceks: true }"
        );
    }

    #[test]
    fn test_result_set_encryption_info() {
        let cek_table = CekTable::new();
        let mut info = ResultSetEncryptionInfo::new(cek_table, 3);

        assert!(!info.is_column_encrypted(0));
        assert!(!info.is_column_encrypted(1));
        assert!(!info.is_column_encrypted(2));

        info.set_column_crypto(1, deterministic_metadata());
        assert!(!info.is_column_encrypted(0));
        assert!(info.is_column_encrypted(1));
        assert!(!info.is_column_encrypted(2));

        assert_eq!(
            info.get_encryption_type(1),
            Some(EncryptionTypeWire::Deterministic)
        );
    }

    #[test]
    fn test_result_set_encryption_info_out_of_range() {
        let mut info = ResultSetEncryptionInfo::new(CekTable::new(), 2);

        // Out-of-range ordinals are ignored rather than growing the vector.
        info.set_column_crypto(5, deterministic_metadata());
        assert_eq!(info.column_crypto.len(), 2);
        assert!(info.column_crypto.iter().all(Option::is_none));

        assert!(!info.is_column_encrypted(5));
        assert_eq!(info.get_encryption_type(5), None);
        assert!(info.get_cek_for_column(5).is_none());

        // In range but unencrypted.
        assert_eq!(info.get_encryption_type(0), None);
        assert!(info.get_cek_for_column(0).is_none());
    }

    #[test]
    fn test_result_set_encryption_info_zero_columns() {
        let info = ResultSetEncryptionInfo::new(CekTable::new(), 0);
        assert!(info.column_crypto.is_empty());
        assert!(!info.is_column_encrypted(0));
    }

    #[test]
    fn test_parameter_encryption_info() {
        let mut info = ParameterEncryptionInfo::new();

        assert!(!info.needs_encryption("@p1"));

        let crypto = ParameterCryptoInfo::new(0, EncryptionTypeWire::Randomized, 2, 1, 1);
        info.add_parameter("@p1".to_string(), crypto);

        assert!(info.needs_encryption("@p1"));
        assert!(!info.needs_encryption("@p2"));

        let param = info.get_parameter("@p1").unwrap();
        assert_eq!(param.encryption_type, EncryptionTypeWire::Randomized);
    }

    #[test]
    fn test_parameter_encryption_info_default_and_overwrite() {
        let mut info = ParameterEncryptionInfo::default();
        assert!(info.parameters.is_empty());
        assert!(info.get_parameter("@p1").is_none());

        info.add_parameter(
            "@p1".to_string(),
            ParameterCryptoInfo::new(0, EncryptionTypeWire::Randomized, 2, 1, 1),
        );
        info.add_parameter(
            "@p1".to_string(),
            ParameterCryptoInfo::new(3, EncryptionTypeWire::Deterministic, 2, 4, 7),
        );

        // The second registration replaces the first.
        assert_eq!(info.parameters.len(), 1);
        let param = info.get_parameter("@p1").unwrap();
        // Also pins the positional argument order of `ParameterCryptoInfo::new`.
        assert_eq!(param.cek_ordinal, 3);
        assert_eq!(param.encryption_type, EncryptionTypeWire::Deterministic);
        assert_eq!(param.algorithm_id, 2);
        assert_eq!(param.column_ordinal, 4);
        assert_eq!(param.database_id, 7);
    }
}