1use crate::common::{CdcError, CdcEvent, Result};
30use aes_gcm::{
31 aead::{Aead, KeyInit, Payload},
32 Aes256Gcm, Nonce,
33};
34use rand::{rngs::OsRng, RngCore};
35use rivven_core::crypto::{KeyInfo, KeyMaterial, KEY_SIZE};
36use serde::{Deserialize, Serialize};
37use sha2::Sha256;
38use std::collections::{HashMap, HashSet};
39use std::sync::atomic::{AtomicU64, Ordering};
40use std::sync::Arc;
41use tokio::sync::RwLock;
42use tracing::warn;
43
/// Algorithm used to encrypt an individual field value.
///
/// Both variants encrypt with AES-256-GCM; they differ only in how the
/// 12-byte nonce is chosen (see `FieldEncryptor::encrypt_value`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum EncryptionAlgorithm {
    /// AES-256-GCM with a nonce filled from the OS random number generator
    /// for every value. This is the default.
    #[default]
    Aes256Gcm,
    /// AES-256-GCM with a nonce taken from the first 12 bytes of an
    /// HMAC-SHA256 over the plaintext, keyed with the encryption key, so the
    /// same plaintext under the same key yields the same nonce.
    Deterministic,
}
53
/// A rule selecting one field of one or more tables for encryption.
#[derive(Debug, Clone)]
pub struct FieldRule {
    /// Table selector: an exact table name, `"*"` for every table, or a
    /// prefix followed by a trailing `*` (see [`FieldRule::matches_table`]).
    pub table_pattern: String,
    /// Name of the field (top-level JSON key) to encrypt.
    pub field_name: String,
    /// Algorithm used for this field.
    pub algorithm: EncryptionAlgorithm,
    /// Optional key id for this rule.
    ///
    /// NOTE(review): `FieldEncryptor::encrypt` always uses the provider's
    /// active key and does not read this field — confirm whether per-rule
    /// keys are meant to be honoured.
    pub key_id: Option<String>,
    /// Whether the field should be masked in logs (defaults to `true`).
    ///
    /// NOTE(review): nothing in this file reads this flag — verify where it
    /// is consumed.
    pub mask_in_logs: bool,
}
68
69impl FieldRule {
70 pub fn new(table: impl Into<String>, field: impl Into<String>) -> Self {
72 Self {
73 table_pattern: table.into(),
74 field_name: field.into(),
75 algorithm: EncryptionAlgorithm::default(),
76 key_id: None,
77 mask_in_logs: true,
78 }
79 }
80
81 pub fn with_algorithm(mut self, algorithm: EncryptionAlgorithm) -> Self {
83 self.algorithm = algorithm;
84 self
85 }
86
87 pub fn with_key_id(mut self, key_id: impl Into<String>) -> Self {
89 self.key_id = Some(key_id.into());
90 self
91 }
92
93 pub fn without_log_masking(mut self) -> Self {
95 self.mask_in_logs = false;
96 self
97 }
98
99 pub fn matches_table(&self, table: &str) -> bool {
101 if self.table_pattern == "*" {
102 return true;
103 }
104 if self.table_pattern.ends_with('*') {
105 let prefix = &self.table_pattern[..self.table_pattern.len() - 1];
106 return table.starts_with(prefix);
107 }
108 self.table_pattern == table
109 }
110}
111
/// Configuration for field-level encryption.
///
/// Note that the derived `Default` yields `enabled == false` and empty
/// strings, whereas [`EncryptionConfigBuilder`] starts from `enabled == true`
/// and a `default_key_id` of `"default"`.
#[derive(Debug, Clone, Default)]
pub struct EncryptionConfig {
    /// Rules, in the order they were added.
    pub rules: Vec<FieldRule>,
    /// Id of the default key.
    ///
    /// NOTE(review): not read by `FieldEncryptor` in this file, which asks
    /// the key provider for its active key instead — confirm intended use.
    pub default_key_id: String,
    /// When `false`, `FieldEncryptor::encrypt` and `decrypt` return the event
    /// unchanged.
    pub enabled: bool,
    /// Prefix placed before `:<key_id>` to form the additional authenticated
    /// data for every encrypted value.
    pub aad_prefix: String,
}
124
125impl EncryptionConfig {
126 pub fn builder() -> EncryptionConfigBuilder {
127 EncryptionConfigBuilder::default()
128 }
129
130 pub fn rules_for_table(&self, table: &str) -> Vec<&FieldRule> {
132 self.rules
133 .iter()
134 .filter(|r| r.matches_table(table))
135 .collect()
136 }
137
138 pub fn fields_for_table(&self, table: &str) -> HashSet<String> {
140 self.rules_for_table(table)
141 .into_iter()
142 .map(|r| r.field_name.clone())
143 .collect()
144 }
145}
146
/// Fluent builder for [`EncryptionConfig`].
pub struct EncryptionConfigBuilder {
    /// The configuration being assembled; returned as-is by `build`.
    config: EncryptionConfig,
}
151
152impl Default for EncryptionConfigBuilder {
153 fn default() -> Self {
154 Self {
155 config: EncryptionConfig {
156 enabled: true,
157 default_key_id: "default".to_string(),
158 rules: Vec::new(),
159 aad_prefix: String::new(),
160 },
161 }
162 }
163}
164
165impl EncryptionConfigBuilder {
166 pub fn new() -> Self {
167 Self::default()
168 }
169
170 pub fn encrypt_field(mut self, table: impl Into<String>, field: impl Into<String>) -> Self {
172 self.config.rules.push(FieldRule::new(table, field));
173 self
174 }
175
176 pub fn encrypt_field_with(
178 mut self,
179 table: impl Into<String>,
180 field: impl Into<String>,
181 algorithm: EncryptionAlgorithm,
182 ) -> Self {
183 self.config
184 .rules
185 .push(FieldRule::new(table, field).with_algorithm(algorithm));
186 self
187 }
188
189 pub fn add_rule(mut self, rule: FieldRule) -> Self {
191 self.config.rules.push(rule);
192 self
193 }
194
195 pub fn default_key_id(mut self, key_id: impl Into<String>) -> Self {
197 self.config.default_key_id = key_id.into();
198 self
199 }
200
201 pub fn enabled(mut self, enabled: bool) -> Self {
203 self.config.enabled = enabled;
204 self
205 }
206
207 pub fn aad_prefix(mut self, prefix: impl Into<String>) -> Self {
209 self.config.aad_prefix = prefix.into();
210 self
211 }
212
213 pub fn build(self) -> EncryptionConfig {
214 self.config
215 }
216}
217
/// An encryption key: public metadata plus private key material.
pub struct EncryptionKey {
    /// Key metadata (id, version, creation time, active flag).
    pub info: KeyInfo,
    /// Raw key bytes; kept private and exposed only through `material()`.
    material: KeyMaterial,
}
229
// Read-only accessors over the key's metadata and material.
impl EncryptionKey {
    /// Returns the key id.
    pub fn id(&self) -> &str {
        &self.info.id
    }

    /// Returns the key version.
    pub fn version(&self) -> u32 {
        self.info.version
    }

    /// Returns the creation timestamp stored in the key metadata.
    // NOTE(review): the unit/epoch is defined by `KeyInfo` — confirm there.
    pub fn created_at(&self) -> u64 {
        self.info.created_at
    }

    /// Returns whether the key metadata marks this key as active.
    pub fn active(&self) -> bool {
        self.info.active
    }

    /// Returns the raw key bytes.
    pub fn material(&self) -> &[u8] {
        self.material.as_bytes()
    }
}
256
257impl Clone for EncryptionKey {
258 fn clone(&self) -> Self {
259 Self {
260 info: self.info.clone(),
261 material: self.material.clone(),
262 }
263 }
264}
265
266impl EncryptionKey {
267 pub fn new(id: impl Into<String>, key_material: Vec<u8>) -> Result<Self> {
269 let material = KeyMaterial::from_bytes(&key_material)
270 .ok_or_else(|| CdcError::replication("Key must be 32 bytes for AES-256"))?;
271 Ok(Self {
272 info: KeyInfo::new(id, 1),
273 material,
274 })
275 }
276
277 pub fn generate(id: impl Into<String>) -> Result<Self> {
279 let mut key_bytes = vec![0u8; KEY_SIZE];
280 OsRng.fill_bytes(&mut key_bytes);
281 Self::new(id, key_bytes)
282 }
283
284 fn to_cipher(&self) -> Result<Aes256Gcm> {
286 Aes256Gcm::new_from_slice(self.material.as_bytes())
287 .map_err(|_| CdcError::replication("Invalid key material"))
288 }
289}
290
/// Source of encryption keys used by `FieldEncryptor`.
///
/// Implementations must be `Send + Sync` so a provider can be shared across
/// tasks.
#[async_trait::async_trait]
pub trait KeyProvider: Send + Sync {
    /// Looks up a key by id. The in-memory provider in this file returns
    /// `Ok(None)` when the id is unknown.
    async fn get_key(&self, key_id: &str) -> Result<Option<EncryptionKey>>;

    /// Returns the key that should be used for new encryptions.
    async fn get_active_key(&self) -> Result<EncryptionKey>;

    /// Stores `key` in the provider.
    async fn store_key(&self, key: EncryptionKey) -> Result<()>;

    /// Rotates the key identified by `key_id` and returns the new key.
    async fn rotate_key(&self, key_id: &str) -> Result<EncryptionKey>;
}
306
/// A [`KeyProvider`] that keeps all keys in process memory.
pub struct MemoryKeyProvider {
    /// Keys indexed by key id; one entry per id.
    keys: RwLock<HashMap<String, EncryptionKey>>,
    /// Id of the key returned by `get_active_key`.
    active_key_id: RwLock<String>,
}
312
313impl MemoryKeyProvider {
314 pub fn new() -> Result<Self> {
316 let default_key = EncryptionKey::generate("default")?;
317 let mut keys = HashMap::new();
318 keys.insert("default".to_string(), default_key);
319 Ok(Self {
320 keys: RwLock::new(keys),
321 active_key_id: RwLock::new("default".to_string()),
322 })
323 }
324
325 pub fn with_key(key: EncryptionKey) -> Self {
327 let key_id = key.id().to_string();
328 let mut keys = HashMap::new();
329 keys.insert(key_id.clone(), key);
330 Self {
331 keys: RwLock::new(keys),
332 active_key_id: RwLock::new(key_id),
333 }
334 }
335}
336
337impl Default for MemoryKeyProvider {
338 fn default() -> Self {
339 Self::new().unwrap()
340 }
341}
342
343#[async_trait::async_trait]
344impl KeyProvider for MemoryKeyProvider {
345 async fn get_key(&self, key_id: &str) -> Result<Option<EncryptionKey>> {
346 let keys = self.keys.read().await;
347 Ok(keys.get(key_id).cloned())
348 }
349
350 async fn get_active_key(&self) -> Result<EncryptionKey> {
351 let active_id = self.active_key_id.read().await.clone();
352 self.get_key(&active_id)
353 .await?
354 .ok_or_else(|| CdcError::replication("No active key found"))
355 }
356
357 async fn store_key(&self, key: EncryptionKey) -> Result<()> {
358 let mut keys = self.keys.write().await;
359 keys.insert(key.id().to_string(), key);
360 Ok(())
361 }
362
363 async fn rotate_key(&self, key_id: &str) -> Result<EncryptionKey> {
364 let new_key = EncryptionKey::generate(key_id)?;
365
366 let mut keys = self.keys.write().await;
368 if let Some(old) = keys.get_mut(key_id) {
369 old.info.active = false;
370 }
371
372 let mut versioned_key = new_key;
374 if let Some(old) = keys.get(key_id) {
375 versioned_key.info.version = old.version() + 1;
376 }
377
378 let key_clone = versioned_key.clone();
379 keys.insert(key_id.to_string(), versioned_key);
380
381 *self.active_key_id.write().await = key_id.to_string();
383
384 Ok(key_clone)
385 }
386}
387
/// Running counters for encryption activity.
///
/// Every counter is an atomic updated with relaxed ordering, so the values
/// can be bumped through a shared reference.
#[derive(Debug, Default)]
pub struct EncryptionStats {
    fields_encrypted: AtomicU64,
    fields_decrypted: AtomicU64,
    encryption_errors: AtomicU64,
    decryption_errors: AtomicU64,
    events_processed: AtomicU64,
}

impl EncryptionStats {
    /// Creates a set of counters, all starting at zero.
    pub fn new() -> Self {
        Self {
            fields_encrypted: AtomicU64::new(0),
            fields_decrypted: AtomicU64::new(0),
            encryption_errors: AtomicU64::new(0),
            decryption_errors: AtomicU64::new(0),
            events_processed: AtomicU64::new(0),
        }
    }

    /// Adds `amount` to `counter`.
    fn bump(counter: &AtomicU64, amount: u64) {
        counter.fetch_add(amount, Ordering::Relaxed);
    }

    /// Adds `count` to the number of fields encrypted.
    pub fn record_encrypted(&self, count: u64) {
        Self::bump(&self.fields_encrypted, count);
    }

    /// Adds `count` to the number of fields decrypted.
    pub fn record_decrypted(&self, count: u64) {
        Self::bump(&self.fields_decrypted, count);
    }

    /// Counts one failed encryption.
    pub fn record_encryption_error(&self) {
        Self::bump(&self.encryption_errors, 1);
    }

    /// Counts one failed decryption.
    pub fn record_decryption_error(&self) {
        Self::bump(&self.decryption_errors, 1);
    }

    /// Counts one processed event.
    pub fn record_event(&self) {
        Self::bump(&self.events_processed, 1);
    }

    /// Copies the current counter values into a plain struct.
    ///
    /// Each counter is loaded separately, so the snapshot is not a single
    /// atomic view across all counters.
    pub fn snapshot(&self) -> EncryptionStatsSnapshot {
        let read = |counter: &AtomicU64| counter.load(Ordering::Relaxed);
        EncryptionStatsSnapshot {
            fields_encrypted: read(&self.fields_encrypted),
            fields_decrypted: read(&self.fields_decrypted),
            encryption_errors: read(&self.encryption_errors),
            decryption_errors: read(&self.decryption_errors),
            events_processed: read(&self.events_processed),
        }
    }
}

/// Plain-value copy of [`EncryptionStats`] taken at one point in time.
#[derive(Debug, Clone)]
pub struct EncryptionStatsSnapshot {
    /// Fields encrypted so far.
    pub fields_encrypted: u64,
    /// Fields decrypted so far.
    pub fields_decrypted: u64,
    /// Encryption failures so far.
    pub encryption_errors: u64,
    /// Decryption failures so far.
    pub decryption_errors: u64,
    /// Events processed so far.
    pub events_processed: u64,
}
443
/// Encrypts and decrypts configured fields of CDC events using keys from a
/// [`KeyProvider`].
pub struct FieldEncryptor<P: KeyProvider> {
    /// Rules and settings controlling which fields are encrypted.
    config: EncryptionConfig,
    /// Source of encryption keys.
    key_provider: Arc<P>,
    /// Counters updated by `encrypt` and `decrypt`.
    stats: EncryptionStats,
}
450
451impl<P: KeyProvider> FieldEncryptor<P> {
452 pub fn new(config: EncryptionConfig, key_provider: P) -> Self {
454 Self {
455 config,
456 key_provider: Arc::new(key_provider),
457 stats: EncryptionStats::new(),
458 }
459 }
460
461 pub async fn encrypt(&self, event: &CdcEvent) -> Result<CdcEvent> {
463 if !self.config.enabled {
464 return Ok(event.clone());
465 }
466
467 self.stats.record_event();
468 let mut result = event.clone();
469 let rules = self.config.rules_for_table(&event.table);
470
471 if rules.is_empty() {
472 return Ok(result);
473 }
474
475 let key = self.key_provider.get_active_key().await?;
476 let cipher = key.to_cipher()?;
477
478 if let Some(ref mut after) = result.after {
480 if let Some(obj) = after.as_object_mut() {
481 let mut encrypted_count = 0u64;
482 for rule in &rules {
483 let field = &rule.field_name;
484 if let Some(value) = obj.get(field) {
485 let plaintext = value.to_string();
486 match self.encrypt_value(
487 &cipher,
488 &plaintext,
489 key.id(),
490 key.material(),
491 rule.algorithm,
492 ) {
493 Ok(ciphertext) => {
494 obj.insert(
495 field.clone(),
496 serde_json::json!({
497 "__encrypted": true,
498 "__key_id": key.id(),
499 "__key_version": key.version(),
500 "__algorithm": format!("{:?}", rule.algorithm),
501 "__value": ciphertext,
502 }),
503 );
504 encrypted_count += 1;
505 }
506 Err(e) => {
507 warn!("Failed to encrypt field {}: {}", field, e);
508 self.stats.record_encryption_error();
509 }
510 }
511 }
512 }
513 self.stats.record_encrypted(encrypted_count);
514 }
515 }
516
517 if let Some(ref mut before) = result.before {
519 if let Some(obj) = before.as_object_mut() {
520 let mut encrypted_count = 0u64;
521 for rule in &rules {
522 let field = &rule.field_name;
523 if let Some(value) = obj.get(field) {
524 let plaintext = value.to_string();
525 match self.encrypt_value(
526 &cipher,
527 &plaintext,
528 key.id(),
529 key.material(),
530 rule.algorithm,
531 ) {
532 Ok(ciphertext) => {
533 obj.insert(
534 field.clone(),
535 serde_json::json!({
536 "__encrypted": true,
537 "__key_id": key.id(),
538 "__key_version": key.version(),
539 "__algorithm": format!("{:?}", rule.algorithm),
540 "__value": ciphertext,
541 }),
542 );
543 encrypted_count += 1;
544 }
545 Err(e) => {
546 warn!("Failed to encrypt field {}: {}", field, e);
547 self.stats.record_encryption_error();
548 }
549 }
550 }
551 }
552 self.stats.record_encrypted(encrypted_count);
553 }
554 }
555
556 Ok(result)
557 }
558
559 pub async fn decrypt(&self, event: &CdcEvent) -> Result<CdcEvent> {
561 if !self.config.enabled {
562 return Ok(event.clone());
563 }
564
565 self.stats.record_event();
566 let mut result = event.clone();
567
568 if let Some(ref mut after) = result.after {
570 if let Some(obj) = after.as_object_mut() {
571 let mut decrypted_count = 0u64;
572 let keys: Vec<_> = obj.keys().cloned().collect();
573
574 for field in keys {
575 if let Some(value) = obj.get(&field) {
576 if let Some(encrypted) = value.as_object() {
577 if encrypted.get("__encrypted") == Some(&serde_json::json!(true)) {
578 if let (Some(key_id), Some(ciphertext)) = (
579 encrypted.get("__key_id").and_then(|v| v.as_str()),
580 encrypted.get("__value").and_then(|v| v.as_str()),
581 ) {
582 match self.decrypt_value(key_id, ciphertext).await {
583 Ok(plaintext) => {
584 let parsed: serde_json::Value = serde_json::from_str(
586 &plaintext,
587 )
588 .unwrap_or_else(|_| serde_json::json!(plaintext));
589 obj.insert(field, parsed);
590 decrypted_count += 1;
591 }
592 Err(e) => {
593 warn!("Failed to decrypt field: {}", e);
594 self.stats.record_decryption_error();
595 }
596 }
597 }
598 }
599 }
600 }
601 }
602 self.stats.record_decrypted(decrypted_count);
603 }
604 }
605
606 if let Some(ref mut before) = result.before {
608 if let Some(obj) = before.as_object_mut() {
609 let mut decrypted_count = 0u64;
610 let keys: Vec<_> = obj.keys().cloned().collect();
611
612 for field in keys {
613 if let Some(value) = obj.get(&field) {
614 if let Some(encrypted) = value.as_object() {
615 if encrypted.get("__encrypted") == Some(&serde_json::json!(true)) {
616 if let (Some(key_id), Some(ciphertext)) = (
617 encrypted.get("__key_id").and_then(|v| v.as_str()),
618 encrypted.get("__value").and_then(|v| v.as_str()),
619 ) {
620 match self.decrypt_value(key_id, ciphertext).await {
621 Ok(plaintext) => {
622 let parsed: serde_json::Value = serde_json::from_str(
623 &plaintext,
624 )
625 .unwrap_or_else(|_| serde_json::json!(plaintext));
626 obj.insert(field, parsed);
627 decrypted_count += 1;
628 }
629 Err(e) => {
630 warn!("Failed to decrypt field: {}", e);
631 self.stats.record_decryption_error();
632 }
633 }
634 }
635 }
636 }
637 }
638 }
639 self.stats.record_decrypted(decrypted_count);
640 }
641 }
642
643 Ok(result)
644 }
645
646 fn encrypt_value(
654 &self,
655 cipher: &Aes256Gcm,
656 plaintext: &str,
657 key_id: &str,
658 key_material: &[u8],
659 algorithm: EncryptionAlgorithm,
660 ) -> Result<String> {
661 let nonce_bytes: [u8; 12] = match algorithm {
662 EncryptionAlgorithm::Aes256Gcm => {
663 let mut n = [0u8; 12];
664 OsRng.fill_bytes(&mut n);
665 n
666 }
667 EncryptionAlgorithm::Deterministic => {
668 use hmac::{Hmac, Mac as HmacMac};
670 type HmacSha256 = Hmac<Sha256>;
671 let mut mac = <HmacSha256 as HmacMac>::new_from_slice(key_material)
672 .map_err(|_| CdcError::replication("HMAC key init failed"))?;
673 mac.update(plaintext.as_bytes());
674 let tag = mac.finalize().into_bytes();
675 let mut n = [0u8; 12];
676 n.copy_from_slice(&tag[..12]);
677 n
678 }
679 };
680 let nonce = Nonce::from_slice(&nonce_bytes);
681
682 let aad = format!("{}:{}", self.config.aad_prefix, key_id);
684
685 let payload = Payload {
687 msg: plaintext.as_bytes(),
688 aad: aad.as_bytes(),
689 };
690 let ciphertext = cipher
691 .encrypt(nonce, payload)
692 .map_err(|_| CdcError::replication("Encryption failed"))?;
693
694 let mut result = nonce_bytes.to_vec();
696 result.extend(ciphertext);
697
698 Ok(base64_encode(&result))
700 }
701
702 async fn decrypt_value(&self, key_id: &str, ciphertext: &str) -> Result<String> {
704 let key = self
705 .key_provider
706 .get_key(key_id)
707 .await?
708 .ok_or_else(|| CdcError::replication(format!("Key not found: {}", key_id)))?;
709 let cipher = key.to_cipher()?;
710
711 let data = base64_decode(ciphertext)?;
713
714 if data.len() < 12 {
715 return Err(CdcError::replication("Invalid ciphertext"));
716 }
717
718 let nonce_bytes: [u8; 12] = data[..12]
720 .try_into()
721 .map_err(|_| CdcError::replication("Invalid nonce"))?;
722 let nonce = Nonce::from_slice(&nonce_bytes);
723 let ciphertext_data = &data[12..];
724
725 let aad = format!("{}:{}", self.config.aad_prefix, key_id);
727
728 let payload = Payload {
730 msg: ciphertext_data,
731 aad: aad.as_bytes(),
732 };
733 let plaintext = cipher
734 .decrypt(nonce, payload)
735 .map_err(|_| CdcError::replication("Decryption failed"))?;
736
737 String::from_utf8(plaintext).map_err(|_| CdcError::replication("Invalid UTF-8"))
738 }
739
740 pub fn stats(&self) -> EncryptionStatsSnapshot {
742 self.stats.snapshot()
743 }
744
745 pub fn is_field_encrypted(value: &serde_json::Value) -> bool {
747 value
748 .as_object()
749 .map(|obj| obj.get("__encrypted") == Some(&serde_json::json!(true)))
750 .unwrap_or(false)
751 }
752}
753
754fn base64_encode(data: &[u8]) -> String {
756 use base64::{engine::general_purpose::STANDARD, Engine};
757 STANDARD.encode(data)
758}
759
760fn base64_decode(s: &str) -> Result<Vec<u8>> {
761 use base64::{engine::general_purpose::STANDARD, Engine};
762 STANDARD
763 .decode(s)
764 .map_err(|e| CdcError::replication(format!("Invalid base64: {}", e)))
765}
766
// Unit tests for rule matching, configuration building, key handling and
// the encrypt/decrypt round trip.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::common::CdcOp;

    /// Builds an insert event for `table` with a fixed `after` image and no
    /// `before` image.
    fn make_event(table: &str) -> CdcEvent {
        CdcEvent {
            source_type: "postgres".to_string(),
            database: "testdb".to_string(),
            schema: "public".to_string(),
            table: table.to_string(),
            op: CdcOp::Insert,
            before: None,
            after: Some(serde_json::json!({
                "id": 1,
                "email": "test@example.com",
                "ssn": "123-45-6789",
                "name": "John Doe"
            })),
            timestamp: chrono::Utc::now().timestamp(),
            transaction: None,
        }
    }

    // A new rule keeps the given names and masks logs by default.
    #[test]
    fn test_field_rule_creation() {
        let rule = FieldRule::new("users", "email");
        assert_eq!(rule.table_pattern, "users");
        assert_eq!(rule.field_name, "email");
        assert!(rule.mask_in_logs);
    }

    // Exact, catch-all and trailing-`*` prefix patterns.
    #[test]
    fn test_field_rule_matching() {
        let rule = FieldRule::new("users", "email");
        assert!(rule.matches_table("users"));
        assert!(!rule.matches_table("orders"));

        let wildcard = FieldRule::new("*", "email");
        assert!(wildcard.matches_table("users"));
        assert!(wildcard.matches_table("orders"));

        let prefix = FieldRule::new("user*", "email");
        assert!(prefix.matches_table("users"));
        assert!(prefix.matches_table("user_profiles"));
        assert!(!prefix.matches_table("orders"));
    }

    // The builder collects rules, applies overrides and defaults to enabled.
    #[test]
    fn test_config_builder() {
        let config = EncryptionConfig::builder()
            .encrypt_field("users", "email")
            .encrypt_field("users", "ssn")
            .encrypt_field("payments", "card_number")
            .default_key_id("my-key")
            .build();

        assert_eq!(config.rules.len(), 3);
        assert_eq!(config.default_key_id, "my-key");
        assert!(config.enabled);
    }

    // Field lookup is scoped to the tables the rules match.
    #[test]
    fn test_config_fields_for_table() {
        let config = EncryptionConfig::builder()
            .encrypt_field("users", "email")
            .encrypt_field("users", "ssn")
            .encrypt_field("orders", "card_number")
            .build();

        let user_fields = config.fields_for_table("users");
        assert_eq!(user_fields.len(), 2);
        assert!(user_fields.contains("email"));
        assert!(user_fields.contains("ssn"));

        let order_fields = config.fields_for_table("orders");
        assert_eq!(order_fields.len(), 1);
        assert!(order_fields.contains("card_number"));

        let other_fields = config.fields_for_table("products");
        assert!(other_fields.is_empty());
    }

    // A generated key keeps its id, starts at version 1 and is active.
    #[test]
    fn test_encryption_key_generation() {
        let key = EncryptionKey::generate("test-key").unwrap();
        assert_eq!(key.id(), "test-key");
        assert_eq!(key.version(), 1);
        assert!(key.active());
    }

    #[test]
    fn test_encryption_key_validation() {
        // 16 bytes is too short for AES-256.
        let result = EncryptionKey::new("test", vec![0u8; 16]);
        assert!(result.is_err());

        // 32 bytes is accepted.
        let result = EncryptionKey::new("test", vec![0u8; 32]);
        assert!(result.is_ok());
    }

    // A fresh provider exposes an active key named "default".
    #[tokio::test]
    async fn test_memory_key_provider() {
        let provider = MemoryKeyProvider::new().unwrap();

        let key = provider.get_active_key().await.unwrap();
        assert_eq!(key.id(), "default");
        assert!(key.active());
    }

    // Rotation keeps the id and increments the version.
    #[tokio::test]
    async fn test_memory_key_provider_rotation() {
        let provider = MemoryKeyProvider::new().unwrap();

        let old_key = provider.get_active_key().await.unwrap();
        let new_key = provider.rotate_key("default").await.unwrap();

        assert_eq!(new_key.id(), "default");
        assert_eq!(new_key.version(), old_key.version() + 1);
    }

    // Configured fields are encrypted, others are untouched, and decrypting
    // restores the original values.
    #[tokio::test]
    async fn test_field_encryptor_encrypt_decrypt() {
        let config = EncryptionConfig::builder()
            .encrypt_field("users", "email")
            .encrypt_field("users", "ssn")
            .build();

        let provider = MemoryKeyProvider::new().unwrap();
        let encryptor = FieldEncryptor::new(config, provider);

        let event = make_event("users");
        let encrypted = encryptor.encrypt(&event).await.unwrap();

        let after = encrypted.after.as_ref().unwrap();
        // Fields with a rule are replaced by envelopes.
        assert!(FieldEncryptor::<MemoryKeyProvider>::is_field_encrypted(
            after.get("email").unwrap()
        ));
        assert!(FieldEncryptor::<MemoryKeyProvider>::is_field_encrypted(
            after.get("ssn").unwrap()
        ));
        // A field without a rule stays as it was.
        assert!(!FieldEncryptor::<MemoryKeyProvider>::is_field_encrypted(
            after.get("name").unwrap()
        ));

        // Round trip back to the original values.
        let decrypted = encryptor.decrypt(&encrypted).await.unwrap();
        let after = decrypted.after.as_ref().unwrap();

        assert_eq!(
            after.get("email").unwrap().as_str().unwrap(),
            "test@example.com"
        );
        assert_eq!(after.get("ssn").unwrap().as_str().unwrap(), "123-45-6789");
        assert_eq!(after.get("name").unwrap().as_str().unwrap(), "John Doe");
    }

    // With no rules configured nothing is encrypted.
    #[tokio::test]
    async fn test_field_encryptor_no_rules() {
        let config = EncryptionConfig::builder().build();
        let provider = MemoryKeyProvider::new().unwrap();
        let encryptor = FieldEncryptor::new(config, provider);

        let event = make_event("users");
        let encrypted = encryptor.encrypt(&event).await.unwrap();

        let after = encrypted.after.as_ref().unwrap();
        assert!(!FieldEncryptor::<MemoryKeyProvider>::is_field_encrypted(
            after.get("email").unwrap()
        ));
    }

    // A disabled configuration passes events through even when rules match.
    #[tokio::test]
    async fn test_field_encryptor_disabled() {
        let config = EncryptionConfig::builder()
            .encrypt_field("users", "email")
            .enabled(false)
            .build();
        let provider = MemoryKeyProvider::new().unwrap();
        let encryptor = FieldEncryptor::new(config, provider);

        let event = make_event("users");
        let encrypted = encryptor.encrypt(&event).await.unwrap();

        let after = encrypted.after.as_ref().unwrap();
        assert!(!FieldEncryptor::<MemoryKeyProvider>::is_field_encrypted(
            after.get("email").unwrap()
        ));
    }

    // Each recorder feeds the matching snapshot field.
    #[test]
    fn test_stats_snapshot() {
        let stats = EncryptionStats::new();
        stats.record_encrypted(10);
        stats.record_decrypted(8);
        stats.record_encryption_error();
        stats.record_decryption_error();
        stats.record_event();
        stats.record_event();

        let snapshot = stats.snapshot();
        assert_eq!(snapshot.fields_encrypted, 10);
        assert_eq!(snapshot.fields_decrypted, 8);
        assert_eq!(snapshot.encryption_errors, 1);
        assert_eq!(snapshot.decryption_errors, 1);
        assert_eq!(snapshot.events_processed, 2);
    }

    #[test]
    fn test_base64_roundtrip() {
        // Text input.
        let data = b"Hello, World!";
        let encoded = base64_encode(data);
        let decoded = base64_decode(&encoded).unwrap();
        assert_eq!(decoded, data);

        // Arbitrary bytes, including values outside the ASCII range.
        let binary = vec![0u8, 1, 2, 255, 254, 253];
        let encoded = base64_encode(&binary);
        let decoded = base64_decode(&encoded).unwrap();
        assert_eq!(decoded, binary);
    }

    // Update events carry both images; the rule applies to each of them.
    #[tokio::test]
    async fn test_encrypt_before_and_after() {
        let config = EncryptionConfig::builder()
            .encrypt_field("users", "email")
            .build();
        let provider = MemoryKeyProvider::new().unwrap();
        let encryptor = FieldEncryptor::new(config, provider);

        let mut event = make_event("users");
        event.op = CdcOp::Update;
        event.before = Some(serde_json::json!({
            "id": 1,
            "email": "old@example.com"
        }));

        let encrypted = encryptor.encrypt(&event).await.unwrap();

        // Both the new and the old row image must be encrypted.
        assert!(FieldEncryptor::<MemoryKeyProvider>::is_field_encrypted(
            encrypted.after.as_ref().unwrap().get("email").unwrap()
        ));
        assert!(FieldEncryptor::<MemoryKeyProvider>::is_field_encrypted(
            encrypted.before.as_ref().unwrap().get("email").unwrap()
        ));
    }
}