Skip to main content

rivven_cdc/common/
encryption.rs

1//! # Column-Level Encryption
2//!
3//! Transparent field-level encryption for sensitive CDC data.
4//! Supports multiple encryption algorithms and key management strategies.
5//!
6//! ## Features
7//!
8//! - **AES-256-GCM**: Default authenticated encryption
9//! - **Field Selection**: Encrypt only sensitive fields
10//! - **Key Rotation**: Support for key versioning and rotation
11//! - **Deterministic Encryption**: Optional for searchable encrypted fields
//! - **Type Round-Tripping**: Decryption restores each field's original JSON type
13//!
14//! ## Usage
15//!
16//! ```ignore
17//! use rivven_cdc::common::encryption::{FieldEncryptor, EncryptionConfig};
18//!
19//! let config = EncryptionConfig::builder()
20//!     .encrypt_field("users", "ssn")
21//!     .encrypt_field("users", "email")
22//!     .encrypt_field("payments", "card_number")
23//!     .build();
24//!
25//! let encryptor = FieldEncryptor::new(config, key_provider);
26//! let encrypted_event = encryptor.encrypt(&event).await?;
27//! ```
28
29use crate::common::{CdcError, CdcEvent, Result};
30use aes_gcm::{
31    aead::{Aead, KeyInit, Payload},
32    Aes256Gcm, Nonce,
33};
34use rand::{rngs::OsRng, RngCore};
35use rivven_core::crypto::{KeyInfo, KeyMaterial, KEY_SIZE};
36use serde::{Deserialize, Serialize};
37use sha2::Sha256;
38use std::collections::{HashMap, HashSet};
39use std::sync::atomic::{AtomicU64, Ordering};
40use std::sync::Arc;
41use tokio::sync::RwLock;
42use tracing::warn;
43
44/// Encryption algorithm.
45#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
46pub enum EncryptionAlgorithm {
47    /// AES-256-GCM (authenticated encryption)
48    #[default]
49    Aes256Gcm,
50    /// Deterministic encryption for searchable fields
51    Deterministic,
52}
53
54/// Field encryption rule.
55#[derive(Debug, Clone)]
56pub struct FieldRule {
57    /// Table pattern (glob)
58    pub table_pattern: String,
59    /// Field name
60    pub field_name: String,
61    /// Encryption algorithm
62    pub algorithm: EncryptionAlgorithm,
63    /// Key ID to use
64    pub key_id: Option<String>,
65    /// Whether to mask in logs
66    pub mask_in_logs: bool,
67}
68
69impl FieldRule {
70    /// Create a new field rule.
71    pub fn new(table: impl Into<String>, field: impl Into<String>) -> Self {
72        Self {
73            table_pattern: table.into(),
74            field_name: field.into(),
75            algorithm: EncryptionAlgorithm::default(),
76            key_id: None,
77            mask_in_logs: true,
78        }
79    }
80
81    /// Set encryption algorithm.
82    pub fn with_algorithm(mut self, algorithm: EncryptionAlgorithm) -> Self {
83        self.algorithm = algorithm;
84        self
85    }
86
87    /// Set key ID.
88    pub fn with_key_id(mut self, key_id: impl Into<String>) -> Self {
89        self.key_id = Some(key_id.into());
90        self
91    }
92
93    /// Disable log masking.
94    pub fn without_log_masking(mut self) -> Self {
95        self.mask_in_logs = false;
96        self
97    }
98
99    /// Check if this rule matches a table.
100    pub fn matches_table(&self, table: &str) -> bool {
101        if self.table_pattern == "*" {
102            return true;
103        }
104        if self.table_pattern.ends_with('*') {
105            let prefix = &self.table_pattern[..self.table_pattern.len() - 1];
106            return table.starts_with(prefix);
107        }
108        self.table_pattern == table
109    }
110}
111
112/// Configuration for field encryption.
113#[derive(Debug, Clone, Default)]
114pub struct EncryptionConfig {
115    /// Field encryption rules
116    pub rules: Vec<FieldRule>,
117    /// Default key ID
118    pub default_key_id: String,
119    /// Enable encryption (can be disabled for testing)
120    pub enabled: bool,
121    /// AAD (additional authenticated data) prefix
122    pub aad_prefix: String,
123}
124
125impl EncryptionConfig {
126    pub fn builder() -> EncryptionConfigBuilder {
127        EncryptionConfigBuilder::default()
128    }
129
130    /// Get rules for a table.
131    pub fn rules_for_table(&self, table: &str) -> Vec<&FieldRule> {
132        self.rules
133            .iter()
134            .filter(|r| r.matches_table(table))
135            .collect()
136    }
137
138    /// Get fields to encrypt for a table.
139    pub fn fields_for_table(&self, table: &str) -> HashSet<String> {
140        self.rules_for_table(table)
141            .into_iter()
142            .map(|r| r.field_name.clone())
143            .collect()
144    }
145}
146
147/// Builder for EncryptionConfig.
148pub struct EncryptionConfigBuilder {
149    config: EncryptionConfig,
150}
151
152impl Default for EncryptionConfigBuilder {
153    fn default() -> Self {
154        Self {
155            config: EncryptionConfig {
156                enabled: true,
157                default_key_id: "default".to_string(),
158                rules: Vec::new(),
159                aad_prefix: String::new(),
160            },
161        }
162    }
163}
164
165impl EncryptionConfigBuilder {
166    pub fn new() -> Self {
167        Self::default()
168    }
169
170    /// Add a field encryption rule.
171    pub fn encrypt_field(mut self, table: impl Into<String>, field: impl Into<String>) -> Self {
172        self.config.rules.push(FieldRule::new(table, field));
173        self
174    }
175
176    /// Add a field rule with custom algorithm.
177    pub fn encrypt_field_with(
178        mut self,
179        table: impl Into<String>,
180        field: impl Into<String>,
181        algorithm: EncryptionAlgorithm,
182    ) -> Self {
183        self.config
184            .rules
185            .push(FieldRule::new(table, field).with_algorithm(algorithm));
186        self
187    }
188
189    /// Add a custom rule.
190    pub fn add_rule(mut self, rule: FieldRule) -> Self {
191        self.config.rules.push(rule);
192        self
193    }
194
195    /// Set default key ID.
196    pub fn default_key_id(mut self, key_id: impl Into<String>) -> Self {
197        self.config.default_key_id = key_id.into();
198        self
199    }
200
201    /// Set enabled state.
202    pub fn enabled(mut self, enabled: bool) -> Self {
203        self.config.enabled = enabled;
204        self
205    }
206
207    /// Set AAD prefix.
208    pub fn aad_prefix(mut self, prefix: impl Into<String>) -> Self {
209        self.config.aad_prefix = prefix.into();
210        self
211    }
212
213    pub fn build(self) -> EncryptionConfig {
214        self.config
215    }
216}
217
218/// Encryption key with metadata.
219///
220/// Wraps [`KeyMaterial`] from `rivven-core::crypto` for the raw key bytes
221/// and [`KeyInfo`] for metadata. Key material is zeroized on drop via
222/// the `KeyMaterial` implementation.
223pub struct EncryptionKey {
224    /// Key metadata (id, version, created_at, active)
225    pub info: KeyInfo,
226    /// Raw key material (32 bytes for AES-256)
227    material: KeyMaterial,
228}
229
230impl EncryptionKey {
231    /// Key ID (convenience accessor)
232    pub fn id(&self) -> &str {
233        &self.info.id
234    }
235
236    /// Key version (convenience accessor)
237    pub fn version(&self) -> u32 {
238        self.info.version
239    }
240
241    /// Creation timestamp (convenience accessor)
242    pub fn created_at(&self) -> u64 {
243        self.info.created_at
244    }
245
246    /// Whether this key is active (convenience accessor)
247    pub fn active(&self) -> bool {
248        self.info.active
249    }
250
251    /// Raw key material bytes (for HMAC / KDF operations).
252    pub fn material(&self) -> &[u8] {
253        self.material.as_bytes()
254    }
255}
256
257impl Clone for EncryptionKey {
258    fn clone(&self) -> Self {
259        Self {
260            info: self.info.clone(),
261            material: self.material.clone(),
262        }
263    }
264}
265
266impl EncryptionKey {
267    /// Create a new encryption key.
268    pub fn new(id: impl Into<String>, key_material: Vec<u8>) -> Result<Self> {
269        let material = KeyMaterial::from_bytes(&key_material)
270            .ok_or_else(|| CdcError::replication("Key must be 32 bytes for AES-256"))?;
271        Ok(Self {
272            info: KeyInfo::new(id, 1),
273            material,
274        })
275    }
276
277    /// Generate a random key using secure random.
278    pub fn generate(id: impl Into<String>) -> Result<Self> {
279        let mut key_bytes = vec![0u8; KEY_SIZE];
280        OsRng.fill_bytes(&mut key_bytes);
281        Self::new(id, key_bytes)
282    }
283
284    /// Create an AEAD cipher.
285    fn to_cipher(&self) -> Result<Aes256Gcm> {
286        Aes256Gcm::new_from_slice(self.material.as_bytes())
287            .map_err(|_| CdcError::replication("Invalid key material"))
288    }
289}
290
291/// Key provider trait.
292#[async_trait::async_trait]
293pub trait KeyProvider: Send + Sync {
294    /// Get a key by ID.
295    async fn get_key(&self, key_id: &str) -> Result<Option<EncryptionKey>>;
296
297    /// Get the active key for encryption.
298    async fn get_active_key(&self) -> Result<EncryptionKey>;
299
300    /// Store a new key.
301    async fn store_key(&self, key: EncryptionKey) -> Result<()>;
302
303    /// Rotate to a new key.
304    async fn rotate_key(&self, key_id: &str) -> Result<EncryptionKey>;
305}
306
307/// In-memory key provider for testing.
308pub struct MemoryKeyProvider {
309    keys: RwLock<HashMap<String, EncryptionKey>>,
310    active_key_id: RwLock<String>,
311}
312
313impl MemoryKeyProvider {
314    /// Create a new memory key provider with a default key.
315    pub fn new() -> Result<Self> {
316        let default_key = EncryptionKey::generate("default")?;
317        let mut keys = HashMap::new();
318        keys.insert("default".to_string(), default_key);
319        Ok(Self {
320            keys: RwLock::new(keys),
321            active_key_id: RwLock::new("default".to_string()),
322        })
323    }
324
325    /// Create with a specific key.
326    pub fn with_key(key: EncryptionKey) -> Self {
327        let key_id = key.id().to_string();
328        let mut keys = HashMap::new();
329        keys.insert(key_id.clone(), key);
330        Self {
331            keys: RwLock::new(keys),
332            active_key_id: RwLock::new(key_id),
333        }
334    }
335}
336
337impl Default for MemoryKeyProvider {
338    fn default() -> Self {
339        Self::new().unwrap()
340    }
341}
342
343#[async_trait::async_trait]
344impl KeyProvider for MemoryKeyProvider {
345    async fn get_key(&self, key_id: &str) -> Result<Option<EncryptionKey>> {
346        let keys = self.keys.read().await;
347        Ok(keys.get(key_id).cloned())
348    }
349
350    async fn get_active_key(&self) -> Result<EncryptionKey> {
351        let active_id = self.active_key_id.read().await.clone();
352        self.get_key(&active_id)
353            .await?
354            .ok_or_else(|| CdcError::replication("No active key found"))
355    }
356
357    async fn store_key(&self, key: EncryptionKey) -> Result<()> {
358        let mut keys = self.keys.write().await;
359        keys.insert(key.id().to_string(), key);
360        Ok(())
361    }
362
363    async fn rotate_key(&self, key_id: &str) -> Result<EncryptionKey> {
364        let new_key = EncryptionKey::generate(key_id)?;
365
366        // Deactivate old key
367        let mut keys = self.keys.write().await;
368        if let Some(old) = keys.get_mut(key_id) {
369            old.info.active = false;
370        }
371
372        // Store new key with incremented version
373        let mut versioned_key = new_key;
374        if let Some(old) = keys.get(key_id) {
375            versioned_key.info.version = old.version() + 1;
376        }
377
378        let key_clone = versioned_key.clone();
379        keys.insert(key_id.to_string(), versioned_key);
380
381        // Update active key ID
382        *self.active_key_id.write().await = key_id.to_string();
383
384        Ok(key_clone)
385    }
386}
387
/// Running counters maintained by `FieldEncryptor`.
///
/// Every counter is an `AtomicU64`, so it can be bumped through `&self`.
#[derive(Debug, Default)]
pub struct EncryptionStats {
    /// Fields successfully encrypted.
    fields_encrypted: AtomicU64,
    /// Fields successfully decrypted.
    fields_decrypted: AtomicU64,
    /// Field encryptions that failed.
    encryption_errors: AtomicU64,
    /// Field decryptions that failed.
    decryption_errors: AtomicU64,
    /// Events handed to `encrypt`/`decrypt` while encryption is enabled.
    events_processed: AtomicU64,
}
397
398impl EncryptionStats {
399    pub fn new() -> Self {
400        Self::default()
401    }
402
403    pub fn record_encrypted(&self, count: u64) {
404        self.fields_encrypted.fetch_add(count, Ordering::Relaxed);
405    }
406
407    pub fn record_decrypted(&self, count: u64) {
408        self.fields_decrypted.fetch_add(count, Ordering::Relaxed);
409    }
410
411    pub fn record_encryption_error(&self) {
412        self.encryption_errors.fetch_add(1, Ordering::Relaxed);
413    }
414
415    pub fn record_decryption_error(&self) {
416        self.decryption_errors.fetch_add(1, Ordering::Relaxed);
417    }
418
419    pub fn record_event(&self) {
420        self.events_processed.fetch_add(1, Ordering::Relaxed);
421    }
422
423    pub fn snapshot(&self) -> EncryptionStatsSnapshot {
424        EncryptionStatsSnapshot {
425            fields_encrypted: self.fields_encrypted.load(Ordering::Relaxed),
426            fields_decrypted: self.fields_decrypted.load(Ordering::Relaxed),
427            encryption_errors: self.encryption_errors.load(Ordering::Relaxed),
428            decryption_errors: self.decryption_errors.load(Ordering::Relaxed),
429            events_processed: self.events_processed.load(Ordering::Relaxed),
430        }
431    }
432}
433
/// Point-in-time copy of [`EncryptionStats`] as plain integers.
#[derive(Clone, Debug)]
pub struct EncryptionStatsSnapshot {
    /// Fields successfully encrypted.
    pub fields_encrypted: u64,
    /// Fields successfully decrypted.
    pub fields_decrypted: u64,
    /// Field encryptions that failed.
    pub encryption_errors: u64,
    /// Field decryptions that failed.
    pub decryption_errors: u64,
    /// Events processed while encryption was enabled.
    pub events_processed: u64,
}
443
444/// Field encryptor for CDC events.
445pub struct FieldEncryptor<P: KeyProvider> {
446    config: EncryptionConfig,
447    key_provider: Arc<P>,
448    stats: EncryptionStats,
449}
450
451impl<P: KeyProvider> FieldEncryptor<P> {
452    /// Create a new field encryptor.
453    pub fn new(config: EncryptionConfig, key_provider: P) -> Self {
454        Self {
455            config,
456            key_provider: Arc::new(key_provider),
457            stats: EncryptionStats::new(),
458        }
459    }
460
461    /// Encrypt sensitive fields in an event.
462    pub async fn encrypt(&self, event: &CdcEvent) -> Result<CdcEvent> {
463        if !self.config.enabled {
464            return Ok(event.clone());
465        }
466
467        self.stats.record_event();
468        let mut result = event.clone();
469        let rules = self.config.rules_for_table(&event.table);
470
471        if rules.is_empty() {
472            return Ok(result);
473        }
474
475        let key = self.key_provider.get_active_key().await?;
476        let cipher = key.to_cipher()?;
477
478        // Encrypt 'after' fields
479        if let Some(ref mut after) = result.after {
480            if let Some(obj) = after.as_object_mut() {
481                let mut encrypted_count = 0u64;
482                for rule in &rules {
483                    let field = &rule.field_name;
484                    if let Some(value) = obj.get(field) {
485                        let plaintext = value.to_string();
486                        match self.encrypt_value(
487                            &cipher,
488                            &plaintext,
489                            key.id(),
490                            key.material(),
491                            rule.algorithm,
492                        ) {
493                            Ok(ciphertext) => {
494                                obj.insert(
495                                    field.clone(),
496                                    serde_json::json!({
497                                        "__encrypted": true,
498                                        "__key_id": key.id(),
499                                        "__key_version": key.version(),
500                                        "__algorithm": format!("{:?}", rule.algorithm),
501                                        "__value": ciphertext,
502                                    }),
503                                );
504                                encrypted_count += 1;
505                            }
506                            Err(e) => {
507                                warn!("Failed to encrypt field {}: {}", field, e);
508                                self.stats.record_encryption_error();
509                            }
510                        }
511                    }
512                }
513                self.stats.record_encrypted(encrypted_count);
514            }
515        }
516
517        // Encrypt 'before' fields
518        if let Some(ref mut before) = result.before {
519            if let Some(obj) = before.as_object_mut() {
520                let mut encrypted_count = 0u64;
521                for rule in &rules {
522                    let field = &rule.field_name;
523                    if let Some(value) = obj.get(field) {
524                        let plaintext = value.to_string();
525                        match self.encrypt_value(
526                            &cipher,
527                            &plaintext,
528                            key.id(),
529                            key.material(),
530                            rule.algorithm,
531                        ) {
532                            Ok(ciphertext) => {
533                                obj.insert(
534                                    field.clone(),
535                                    serde_json::json!({
536                                        "__encrypted": true,
537                                        "__key_id": key.id(),
538                                        "__key_version": key.version(),
539                                        "__algorithm": format!("{:?}", rule.algorithm),
540                                        "__value": ciphertext,
541                                    }),
542                                );
543                                encrypted_count += 1;
544                            }
545                            Err(e) => {
546                                warn!("Failed to encrypt field {}: {}", field, e);
547                                self.stats.record_encryption_error();
548                            }
549                        }
550                    }
551                }
552                self.stats.record_encrypted(encrypted_count);
553            }
554        }
555
556        Ok(result)
557    }
558
559    /// Decrypt sensitive fields in an event.
560    pub async fn decrypt(&self, event: &CdcEvent) -> Result<CdcEvent> {
561        if !self.config.enabled {
562            return Ok(event.clone());
563        }
564
565        self.stats.record_event();
566        let mut result = event.clone();
567
568        // Decrypt 'after' fields
569        if let Some(ref mut after) = result.after {
570            if let Some(obj) = after.as_object_mut() {
571                let mut decrypted_count = 0u64;
572                let keys: Vec<_> = obj.keys().cloned().collect();
573
574                for field in keys {
575                    if let Some(value) = obj.get(&field) {
576                        if let Some(encrypted) = value.as_object() {
577                            if encrypted.get("__encrypted") == Some(&serde_json::json!(true)) {
578                                if let (Some(key_id), Some(ciphertext)) = (
579                                    encrypted.get("__key_id").and_then(|v| v.as_str()),
580                                    encrypted.get("__value").and_then(|v| v.as_str()),
581                                ) {
582                                    match self.decrypt_value(key_id, ciphertext).await {
583                                        Ok(plaintext) => {
584                                            // Parse back to original JSON type
585                                            let parsed: serde_json::Value = serde_json::from_str(
586                                                &plaintext,
587                                            )
588                                            .unwrap_or_else(|_| serde_json::json!(plaintext));
589                                            obj.insert(field, parsed);
590                                            decrypted_count += 1;
591                                        }
592                                        Err(e) => {
593                                            warn!("Failed to decrypt field: {}", e);
594                                            self.stats.record_decryption_error();
595                                        }
596                                    }
597                                }
598                            }
599                        }
600                    }
601                }
602                self.stats.record_decrypted(decrypted_count);
603            }
604        }
605
606        // Decrypt 'before' fields
607        if let Some(ref mut before) = result.before {
608            if let Some(obj) = before.as_object_mut() {
609                let mut decrypted_count = 0u64;
610                let keys: Vec<_> = obj.keys().cloned().collect();
611
612                for field in keys {
613                    if let Some(value) = obj.get(&field) {
614                        if let Some(encrypted) = value.as_object() {
615                            if encrypted.get("__encrypted") == Some(&serde_json::json!(true)) {
616                                if let (Some(key_id), Some(ciphertext)) = (
617                                    encrypted.get("__key_id").and_then(|v| v.as_str()),
618                                    encrypted.get("__value").and_then(|v| v.as_str()),
619                                ) {
620                                    match self.decrypt_value(key_id, ciphertext).await {
621                                        Ok(plaintext) => {
622                                            let parsed: serde_json::Value = serde_json::from_str(
623                                                &plaintext,
624                                            )
625                                            .unwrap_or_else(|_| serde_json::json!(plaintext));
626                                            obj.insert(field, parsed);
627                                            decrypted_count += 1;
628                                        }
629                                        Err(e) => {
630                                            warn!("Failed to decrypt field: {}", e);
631                                            self.stats.record_decryption_error();
632                                        }
633                                    }
634                                }
635                            }
636                        }
637                    }
638                }
639                self.stats.record_decrypted(decrypted_count);
640            }
641        }
642
643        Ok(result)
644    }
645
646    /// Encrypt a single value.
647    ///
648    /// For `Aes256Gcm` (default): generates a random 12-byte nonce per call.
649    /// For `Deterministic`: derives the nonce from `HMAC-SHA256(key_material, plaintext)`
650    /// truncated to 12 bytes.  Equal plaintexts encrypted with the same key
651    /// produce identical ciphertext — enabling equality searches but leaking
652    /// equality information to anyone with read access to the ciphertext.
653    fn encrypt_value(
654        &self,
655        cipher: &Aes256Gcm,
656        plaintext: &str,
657        key_id: &str,
658        key_material: &[u8],
659        algorithm: EncryptionAlgorithm,
660    ) -> Result<String> {
661        let nonce_bytes: [u8; 12] = match algorithm {
662            EncryptionAlgorithm::Aes256Gcm => {
663                let mut n = [0u8; 12];
664                OsRng.fill_bytes(&mut n);
665                n
666            }
667            EncryptionAlgorithm::Deterministic => {
668                // SIV-like construction: HMAC(key, plaintext) → nonce
669                use hmac::{Hmac, Mac as HmacMac};
670                type HmacSha256 = Hmac<Sha256>;
671                let mut mac = <HmacSha256 as HmacMac>::new_from_slice(key_material)
672                    .map_err(|_| CdcError::replication("HMAC key init failed"))?;
673                mac.update(plaintext.as_bytes());
674                let tag = mac.finalize().into_bytes();
675                let mut n = [0u8; 12];
676                n.copy_from_slice(&tag[..12]);
677                n
678            }
679        };
680        let nonce = Nonce::from_slice(&nonce_bytes);
681
682        // AAD includes key_id for additional authentication
683        let aad = format!("{}:{}", self.config.aad_prefix, key_id);
684
685        // Encrypt with AAD (Additional Authenticated Data)
686        let payload = Payload {
687            msg: plaintext.as_bytes(),
688            aad: aad.as_bytes(),
689        };
690        let ciphertext = cipher
691            .encrypt(nonce, payload)
692            .map_err(|_| CdcError::replication("Encryption failed"))?;
693
694        // Prepend nonce to ciphertext
695        let mut result = nonce_bytes.to_vec();
696        result.extend(ciphertext);
697
698        // Base64 encode
699        Ok(base64_encode(&result))
700    }
701
702    /// Decrypt a single value.
703    async fn decrypt_value(&self, key_id: &str, ciphertext: &str) -> Result<String> {
704        let key = self
705            .key_provider
706            .get_key(key_id)
707            .await?
708            .ok_or_else(|| CdcError::replication(format!("Key not found: {}", key_id)))?;
709        let cipher = key.to_cipher()?;
710
711        // Base64 decode
712        let data = base64_decode(ciphertext)?;
713
714        if data.len() < 12 {
715            return Err(CdcError::replication("Invalid ciphertext"));
716        }
717
718        // Extract nonce and ciphertext
719        let nonce_bytes: [u8; 12] = data[..12]
720            .try_into()
721            .map_err(|_| CdcError::replication("Invalid nonce"))?;
722        let nonce = Nonce::from_slice(&nonce_bytes);
723        let ciphertext_data = &data[12..];
724
725        // AAD must match what was used during encryption
726        let aad = format!("{}:{}", self.config.aad_prefix, key_id);
727
728        // Decrypt with AAD verification
729        let payload = Payload {
730            msg: ciphertext_data,
731            aad: aad.as_bytes(),
732        };
733        let plaintext = cipher
734            .decrypt(nonce, payload)
735            .map_err(|_| CdcError::replication("Decryption failed"))?;
736
737        String::from_utf8(plaintext).map_err(|_| CdcError::replication("Invalid UTF-8"))
738    }
739
740    /// Get statistics.
741    pub fn stats(&self) -> EncryptionStatsSnapshot {
742        self.stats.snapshot()
743    }
744
745    /// Check if a field is encrypted.
746    pub fn is_field_encrypted(value: &serde_json::Value) -> bool {
747        value
748            .as_object()
749            .map(|obj| obj.get("__encrypted") == Some(&serde_json::json!(true)))
750            .unwrap_or(false)
751    }
752}
753
754// Base64 encoding/decoding helpers (using the `base64` crate)
755fn base64_encode(data: &[u8]) -> String {
756    use base64::{engine::general_purpose::STANDARD, Engine};
757    STANDARD.encode(data)
758}
759
760fn base64_decode(s: &str) -> Result<Vec<u8>> {
761    use base64::{engine::general_purpose::STANDARD, Engine};
762    STANDARD
763        .decode(s)
764        .map_err(|e| CdcError::replication(format!("Invalid base64: {}", e)))
765}
766
767#[cfg(test)]
768mod tests {
769    use super::*;
770    use crate::common::CdcOp;
771
772    fn make_event(table: &str) -> CdcEvent {
773        CdcEvent {
774            source_type: "postgres".to_string(),
775            database: "testdb".to_string(),
776            schema: "public".to_string(),
777            table: table.to_string(),
778            op: CdcOp::Insert,
779            before: None,
780            after: Some(serde_json::json!({
781                "id": 1,
782                "email": "test@example.com",
783                "ssn": "123-45-6789",
784                "name": "John Doe"
785            })),
786            timestamp: chrono::Utc::now().timestamp(),
787            transaction: None,
788        }
789    }
790
791    #[test]
792    fn test_field_rule_creation() {
793        let rule = FieldRule::new("users", "email");
794        assert_eq!(rule.table_pattern, "users");
795        assert_eq!(rule.field_name, "email");
796        assert!(rule.mask_in_logs);
797    }
798
799    #[test]
800    fn test_field_rule_matching() {
801        let rule = FieldRule::new("users", "email");
802        assert!(rule.matches_table("users"));
803        assert!(!rule.matches_table("orders"));
804
805        let wildcard = FieldRule::new("*", "email");
806        assert!(wildcard.matches_table("users"));
807        assert!(wildcard.matches_table("orders"));
808
809        let prefix = FieldRule::new("user*", "email");
810        assert!(prefix.matches_table("users"));
811        assert!(prefix.matches_table("user_profiles"));
812        assert!(!prefix.matches_table("orders"));
813    }
814
815    #[test]
816    fn test_config_builder() {
817        let config = EncryptionConfig::builder()
818            .encrypt_field("users", "email")
819            .encrypt_field("users", "ssn")
820            .encrypt_field("payments", "card_number")
821            .default_key_id("my-key")
822            .build();
823
824        assert_eq!(config.rules.len(), 3);
825        assert_eq!(config.default_key_id, "my-key");
826        assert!(config.enabled);
827    }
828
829    #[test]
830    fn test_config_fields_for_table() {
831        let config = EncryptionConfig::builder()
832            .encrypt_field("users", "email")
833            .encrypt_field("users", "ssn")
834            .encrypt_field("orders", "card_number")
835            .build();
836
837        let user_fields = config.fields_for_table("users");
838        assert_eq!(user_fields.len(), 2);
839        assert!(user_fields.contains("email"));
840        assert!(user_fields.contains("ssn"));
841
842        let order_fields = config.fields_for_table("orders");
843        assert_eq!(order_fields.len(), 1);
844        assert!(order_fields.contains("card_number"));
845
846        let other_fields = config.fields_for_table("products");
847        assert!(other_fields.is_empty());
848    }
849
850    #[test]
851    fn test_encryption_key_generation() {
852        let key = EncryptionKey::generate("test-key").unwrap();
853        assert_eq!(key.id(), "test-key");
854        assert_eq!(key.version(), 1);
855        assert!(key.active());
856    }
857
858    #[test]
859    fn test_encryption_key_validation() {
860        // Too short
861        let result = EncryptionKey::new("test", vec![0u8; 16]);
862        assert!(result.is_err());
863
864        // Correct size
865        let result = EncryptionKey::new("test", vec![0u8; 32]);
866        assert!(result.is_ok());
867    }
868
869    #[tokio::test]
870    async fn test_memory_key_provider() {
871        let provider = MemoryKeyProvider::new().unwrap();
872
873        let key = provider.get_active_key().await.unwrap();
874        assert_eq!(key.id(), "default");
875        assert!(key.active());
876    }
877
878    #[tokio::test]
879    async fn test_memory_key_provider_rotation() {
880        let provider = MemoryKeyProvider::new().unwrap();
881
882        let old_key = provider.get_active_key().await.unwrap();
883        let new_key = provider.rotate_key("default").await.unwrap();
884
885        assert_eq!(new_key.id(), "default");
886        assert_eq!(new_key.version(), old_key.version() + 1);
887    }
888
889    #[tokio::test]
890    async fn test_field_encryptor_encrypt_decrypt() {
891        let config = EncryptionConfig::builder()
892            .encrypt_field("users", "email")
893            .encrypt_field("users", "ssn")
894            .build();
895
896        let provider = MemoryKeyProvider::new().unwrap();
897        let encryptor = FieldEncryptor::new(config, provider);
898
899        let event = make_event("users");
900        let encrypted = encryptor.encrypt(&event).await.unwrap();
901
902        // Check that fields are encrypted
903        let after = encrypted.after.as_ref().unwrap();
904        assert!(FieldEncryptor::<MemoryKeyProvider>::is_field_encrypted(
905            after.get("email").unwrap()
906        ));
907        assert!(FieldEncryptor::<MemoryKeyProvider>::is_field_encrypted(
908            after.get("ssn").unwrap()
909        ));
910        // Name should not be encrypted
911        assert!(!FieldEncryptor::<MemoryKeyProvider>::is_field_encrypted(
912            after.get("name").unwrap()
913        ));
914
915        // Decrypt
916        let decrypted = encryptor.decrypt(&encrypted).await.unwrap();
917        let after = decrypted.after.as_ref().unwrap();
918
919        // Values should be restored (compare as JSON strings)
920        assert_eq!(
921            after.get("email").unwrap().as_str().unwrap(),
922            "test@example.com"
923        );
924        assert_eq!(after.get("ssn").unwrap().as_str().unwrap(), "123-45-6789");
925        assert_eq!(after.get("name").unwrap().as_str().unwrap(), "John Doe");
926    }
927
928    #[tokio::test]
929    async fn test_field_encryptor_no_rules() {
930        let config = EncryptionConfig::builder().build();
931        let provider = MemoryKeyProvider::new().unwrap();
932        let encryptor = FieldEncryptor::new(config, provider);
933
934        let event = make_event("users");
935        let encrypted = encryptor.encrypt(&event).await.unwrap();
936
937        // Nothing should be encrypted
938        let after = encrypted.after.as_ref().unwrap();
939        assert!(!FieldEncryptor::<MemoryKeyProvider>::is_field_encrypted(
940            after.get("email").unwrap()
941        ));
942    }
943
944    #[tokio::test]
945    async fn test_field_encryptor_disabled() {
946        let config = EncryptionConfig::builder()
947            .encrypt_field("users", "email")
948            .enabled(false)
949            .build();
950        let provider = MemoryKeyProvider::new().unwrap();
951        let encryptor = FieldEncryptor::new(config, provider);
952
953        let event = make_event("users");
954        let encrypted = encryptor.encrypt(&event).await.unwrap();
955
956        // Nothing should be encrypted when disabled
957        let after = encrypted.after.as_ref().unwrap();
958        assert!(!FieldEncryptor::<MemoryKeyProvider>::is_field_encrypted(
959            after.get("email").unwrap()
960        ));
961    }
962
963    #[test]
964    fn test_stats_snapshot() {
965        let stats = EncryptionStats::new();
966        stats.record_encrypted(10);
967        stats.record_decrypted(8);
968        stats.record_encryption_error();
969        stats.record_decryption_error();
970        stats.record_event();
971        stats.record_event();
972
973        let snapshot = stats.snapshot();
974        assert_eq!(snapshot.fields_encrypted, 10);
975        assert_eq!(snapshot.fields_decrypted, 8);
976        assert_eq!(snapshot.encryption_errors, 1);
977        assert_eq!(snapshot.decryption_errors, 1);
978        assert_eq!(snapshot.events_processed, 2);
979    }
980
981    #[test]
982    fn test_base64_roundtrip() {
983        let data = b"Hello, World!";
984        let encoded = base64_encode(data);
985        let decoded = base64_decode(&encoded).unwrap();
986        assert_eq!(decoded, data);
987
988        // Test with binary data
989        let binary = vec![0u8, 1, 2, 255, 254, 253];
990        let encoded = base64_encode(&binary);
991        let decoded = base64_decode(&encoded).unwrap();
992        assert_eq!(decoded, binary);
993    }
994
995    #[tokio::test]
996    async fn test_encrypt_before_and_after() {
997        let config = EncryptionConfig::builder()
998            .encrypt_field("users", "email")
999            .build();
1000        let provider = MemoryKeyProvider::new().unwrap();
1001        let encryptor = FieldEncryptor::new(config, provider);
1002
1003        let mut event = make_event("users");
1004        event.op = CdcOp::Update;
1005        event.before = Some(serde_json::json!({
1006            "id": 1,
1007            "email": "old@example.com"
1008        }));
1009
1010        let encrypted = encryptor.encrypt(&event).await.unwrap();
1011
1012        // Both before and after should have encrypted email
1013        assert!(FieldEncryptor::<MemoryKeyProvider>::is_field_encrypted(
1014            encrypted.after.as_ref().unwrap().get("email").unwrap()
1015        ));
1016        assert!(FieldEncryptor::<MemoryKeyProvider>::is_field_encrypted(
1017            encrypted.before.as_ref().unwrap().get("email").unwrap()
1018        ));
1019    }
1020}