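//! neleus_db/encryption.rs
//!
//! Password-based authenticated encryption providers (AES-256-GCM, ChaCha20-Poly1305) with
//! PBKDF2-HMAC-SHA256 key derivation.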

1use aes_gcm::aead::{Aead, KeyInit};
2use aes_gcm::{Aes256Gcm, Nonce as AesNonce};
3use anyhow::{Result, anyhow};
4use chacha20poly1305::{ChaCha20Poly1305, Nonce as ChaChaNonce};
5use pbkdf2::pbkdf2_hmac;
6use serde::{Deserialize, Serialize};
7use sha2::Sha256;
8use std::time::{SystemTime, UNIX_EPOCH};
9
/// Settings that select and parameterize the password-based encryption used by
/// [`EncryptionManager`].
///
/// When deserializing, `nonce_size` and `kdf_iterations` fall back to defaults if absent;
/// every other field must be present.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EncryptionConfig {
    /// When `false`, the manager uses the pass-through `NoOpEncryption` provider.
    pub enabled: bool,
    /// Cipher name, matched case-insensitively by `EncryptionManager::from_config`:
    /// `"aes-256-gcm"`, `"chacha20-poly1305"`, or `"none"`.
    pub algorithm: String,
    /// Key-derivation function name; `derive_key` accepts only `"pbkdf2"` (case-insensitive).
    pub kdf: String,
    /// Derived key length in bytes; the AEAD providers require exactly 32.
    pub key_size: usize,
    /// Length in bytes of the random salt generated per encryption; AEAD providers require >= 16.
    pub salt_size: usize,
    /// Length in bytes of the random nonce generated per encryption; AEAD providers require 12.
    #[serde(default = "default_nonce_size")]
    pub nonce_size: usize,
    /// PBKDF2 iteration count used when encrypting; AEAD providers require >= 10,000.
    #[serde(default = "default_kdf_iterations")]
    pub kdf_iterations: u32,
}
22
23impl Default for EncryptionConfig {
24    fn default() -> Self {
25        Self {
26            enabled: false,
27            algorithm: "aes-256-gcm".to_string(),
28            kdf: "pbkdf2".to_string(),
29            key_size: 32,
30            salt_size: 16,
31            nonce_size: default_nonce_size(),
32            kdf_iterations: default_kdf_iterations(),
33        }
34    }
35}
36
/// Self-describing envelope produced by an [`EncryptionProvider`]; [`EncryptionManager`]
/// stores it as JSON bytes.
///
/// Byte fields are serialized as hex strings via `hex_serde`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EncryptedData {
    /// Envelope format version. Every provider in this file writes `1`; the decrypt paths do
    /// not currently inspect it.
    pub version: u32,
    /// Name of the provider that produced the envelope; providers reject any other name.
    pub algorithm: String,
    /// KDF used to derive the key. Defaults to `"pbkdf2"` when missing from the input; an
    /// empty string makes the AEAD providers fall back to their configured KDF.
    #[serde(default = "default_kdf")]
    pub kdf: String,
    /// KDF iteration count used at encryption time. Defaults to `default_kdf_iterations()`
    /// when missing; `0` (written by the no-op provider) makes the AEAD providers fall back
    /// to their configured count.
    #[serde(default = "default_kdf_iterations")]
    pub iterations: u32,
    /// Random KDF salt (empty for the no-op provider).
    #[serde(with = "hex_serde")]
    pub salt: Vec<u8>,
    /// Random AEAD nonce (empty for the no-op provider).
    #[serde(with = "hex_serde")]
    pub nonce: Vec<u8>,
    /// Output of the cipher's `encrypt`; for the no-op provider this is the plaintext itself.
    #[serde(with = "hex_serde")]
    pub ciphertext: Vec<u8>,
    /// Unix timestamp (seconds) at encryption. If missing during deserialization it is filled
    /// with the *current* time, not the original one.
    #[serde(default = "now_unix")]
    pub created_at: u64,
    /// Free-form annotations. Providers here always write an empty map. It is not bound
    /// into the AEAD tag (no associated data is passed), so it is not tamper-protected.
    #[serde(default)]
    pub metadata: std::collections::BTreeMap<String, String>,
}
56
/// Serde fallback for `EncryptedData::kdf` when a payload omits the field.
fn default_kdf() -> String {
    String::from("pbkdf2")
}
60
/// Default PBKDF2 iteration count.
///
/// This value has three roles:
/// - the `EncryptionConfig::default()` value;
/// - the serde fallback for a missing `EncryptionConfig::kdf_iterations`;
/// - the serde fallback for a missing `EncryptedData::iterations`.
///
/// Because of the last role, changing this number changes the key derived for stored
/// payloads that lack an `iterations` field, and their decryption would then fail.
///
/// NOTE(review): 210,000 matches OWASP's guidance for PBKDF2-HMAC-SHA512. `derive_key_pbkdf2`
/// uses HMAC-SHA256, for which OWASP currently recommends 600,000. If raising it, change the
/// config default separately rather than this shared fallback.
fn default_kdf_iterations() -> u32 {
    const ITERATIONS: u32 = 210_000;
    ITERATIONS
}
64
/// Default nonce length in bytes: 96 bits, the only size the AEAD providers accept.
fn default_nonce_size() -> usize {
    96 / 8
}
68
69mod hex_serde {
70    use serde::{Deserialize, Deserializer, Serializer};
71
72    pub fn serialize<S>(bytes: &[u8], serializer: S) -> Result<S::Ok, S::Error>
73    where
74        S: Serializer,
75    {
76        serializer.serialize_str(&hex::encode(bytes))
77    }
78
79    pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<u8>, D::Error>
80    where
81        D: Deserializer<'de>,
82    {
83        let s = String::deserialize(deserializer)?;
84        hex::decode(&s).map_err(serde::de::Error::custom)
85    }
86}
87
/// A password-based encryption backend that produces and consumes [`EncryptedData`] envelopes.
///
/// Implementations in this file are [`NoOpEncryption`], [`Aes256GcmEncryption`] and
/// [`ChaCha20Poly1305Encryption`]. The `Send + Sync` supertraits make
/// `Box<dyn EncryptionProvider>` usable across threads.
pub trait EncryptionProvider: Send + Sync {
    /// Encrypts `plaintext` under `password`, returning a self-describing envelope.
    fn encrypt(&self, plaintext: &[u8], password: &str) -> Result<EncryptedData>;
    /// Recovers the plaintext from `encrypted` using `password`.
    ///
    /// The implementations in this file reject envelopes whose `algorithm` differs from
    /// their own [`EncryptionProvider::algorithm`].
    fn decrypt(&self, encrypted: &EncryptedData, password: &str) -> Result<Vec<u8>>;
    /// Identifier written to, and expected in, `EncryptedData::algorithm`.
    fn algorithm(&self) -> &str;
}
93
94pub struct NoOpEncryption;
95
96impl EncryptionProvider for NoOpEncryption {
97    fn encrypt(&self, plaintext: &[u8], _password: &str) -> Result<EncryptedData> {
98        Ok(EncryptedData {
99            version: 1,
100            algorithm: "none".to_string(),
101            kdf: "none".to_string(),
102            iterations: 0,
103            salt: Vec::new(),
104            nonce: Vec::new(),
105            ciphertext: plaintext.to_vec(),
106            created_at: now_unix(),
107            metadata: Default::default(),
108        })
109    }
110
111    fn decrypt(&self, encrypted: &EncryptedData, _password: &str) -> Result<Vec<u8>> {
112        if encrypted.algorithm != "none" {
113            return Err(anyhow!(
114                "NoOp provider can only decrypt algorithm 'none', got '{}'",
115                encrypted.algorithm
116            ));
117        }
118        Ok(encrypted.ciphertext.clone())
119    }
120
121    fn algorithm(&self) -> &str {
122        "none"
123    }
124}
125
126pub struct EncryptionManager {
127    provider: Box<dyn EncryptionProvider>,
128    config: EncryptionConfig,
129}
130
131impl EncryptionManager {
132    pub fn new(config: EncryptionConfig, provider: Box<dyn EncryptionProvider>) -> Result<Self> {
133        if !config.enabled {
134            return Ok(Self {
135                provider: Box::new(NoOpEncryption),
136                config,
137            });
138        }
139
140        Ok(Self { provider, config })
141    }
142
143    pub fn from_config(config: EncryptionConfig) -> Result<Self> {
144        if !config.enabled {
145            return Ok(Self {
146                provider: Box::new(NoOpEncryption),
147                config,
148            });
149        }
150
151        let normalized = config.algorithm.to_ascii_lowercase();
152        let provider: Box<dyn EncryptionProvider> = match normalized.as_str() {
153            "aes-256-gcm" => Box::new(Aes256GcmEncryption::new(config.clone())?),
154            "chacha20-poly1305" => Box::new(ChaCha20Poly1305Encryption::new(config.clone())?),
155            "none" => Box::new(NoOpEncryption),
156            _ => {
157                return Err(anyhow!(
158                    "unsupported encryption algorithm '{}'; expected aes-256-gcm or chacha20-poly1305",
159                    config.algorithm
160                ));
161            }
162        };
163
164        Ok(Self { provider, config })
165    }
166
167    pub fn disabled() -> Self {
168        Self {
169            provider: Box::new(NoOpEncryption),
170            config: EncryptionConfig::default(),
171        }
172    }
173
174    pub fn encrypt(&self, plaintext: &[u8], password: &str) -> Result<Vec<u8>> {
175        let encrypted = self.provider.encrypt(plaintext, password)?;
176        let serialized = serde_json::to_vec(&encrypted)?;
177        Ok(serialized)
178    }
179
180    pub fn decrypt(&self, ciphertext: &[u8], password: &str) -> Result<Vec<u8>> {
181        let encrypted: EncryptedData = serde_json::from_slice(ciphertext)?;
182        self.provider.decrypt(&encrypted, password)
183    }
184
185    pub fn is_enabled(&self) -> bool {
186        self.config.enabled
187    }
188
189    pub fn config(&self) -> &EncryptionConfig {
190        &self.config
191    }
192}
193
194pub struct Aes256GcmEncryption {
195    config: EncryptionConfig,
196}
197
198impl Aes256GcmEncryption {
199    pub fn new(config: EncryptionConfig) -> Result<Self> {
200        validate_aead_config(&config, "aes-256-gcm")?;
201        Ok(Self { config })
202    }
203}
204
205impl EncryptionProvider for Aes256GcmEncryption {
206    fn encrypt(&self, plaintext: &[u8], password: &str) -> Result<EncryptedData> {
207        let salt = utils::random_bytes(self.config.salt_size)?;
208        let nonce = utils::random_bytes(self.config.nonce_size)?;
209        let key = derive_key(
210            password,
211            &salt,
212            &self.config.kdf,
213            self.config.kdf_iterations,
214            self.config.key_size,
215        )?;
216
217        let cipher = Aes256Gcm::new_from_slice(&key)
218            .map_err(|_| anyhow!("invalid AES-256-GCM key size"))?;
219        let nonce_ref = AesNonce::from_slice(&nonce);
220
221        let ciphertext = cipher
222            .encrypt(nonce_ref, plaintext)
223            .map_err(|e| anyhow!("AES-256-GCM encryption failed: {e}"))?;
224
225        Ok(EncryptedData {
226            version: 1,
227            algorithm: self.algorithm().to_string(),
228            kdf: self.config.kdf.clone(),
229            iterations: self.config.kdf_iterations,
230            salt,
231            nonce,
232            ciphertext,
233            created_at: now_unix(),
234            metadata: Default::default(),
235        })
236    }
237
238    fn decrypt(&self, encrypted: &EncryptedData, password: &str) -> Result<Vec<u8>> {
239        if encrypted.algorithm != self.algorithm() {
240            return Err(anyhow!(
241                "algorithm mismatch: provider={} payload={}",
242                self.algorithm(),
243                encrypted.algorithm
244            ));
245        }
246        if encrypted.nonce.len() != self.config.nonce_size {
247            return Err(anyhow!(
248                "invalid nonce size: expected {}, got {}",
249                self.config.nonce_size,
250                encrypted.nonce.len()
251            ));
252        }
253
254        let kdf = if encrypted.kdf.is_empty() {
255            &self.config.kdf
256        } else {
257            &encrypted.kdf
258        };
259        let iterations = if encrypted.iterations == 0 {
260            self.config.kdf_iterations
261        } else {
262            encrypted.iterations
263        };
264
265        let key = derive_key(
266            password,
267            &encrypted.salt,
268            kdf,
269            iterations,
270            self.config.key_size,
271        )?;
272
273        let cipher = Aes256Gcm::new_from_slice(&key)
274            .map_err(|_| anyhow!("invalid AES-256-GCM key size"))?;
275        let nonce_ref = AesNonce::from_slice(&encrypted.nonce);
276
277        cipher
278            .decrypt(nonce_ref, encrypted.ciphertext.as_ref())
279            .map_err(|_| anyhow!("AES-256-GCM authentication failed (wrong password or tampered data)"))
280    }
281
282    fn algorithm(&self) -> &str {
283        "aes-256-gcm"
284    }
285}
286
287pub struct ChaCha20Poly1305Encryption {
288    config: EncryptionConfig,
289}
290
291impl ChaCha20Poly1305Encryption {
292    pub fn new(config: EncryptionConfig) -> Result<Self> {
293        validate_aead_config(&config, "chacha20-poly1305")?;
294        Ok(Self { config })
295    }
296}
297
298impl EncryptionProvider for ChaCha20Poly1305Encryption {
299    fn encrypt(&self, plaintext: &[u8], password: &str) -> Result<EncryptedData> {
300        let salt = utils::random_bytes(self.config.salt_size)?;
301        let nonce = utils::random_bytes(self.config.nonce_size)?;
302        let key = derive_key(
303            password,
304            &salt,
305            &self.config.kdf,
306            self.config.kdf_iterations,
307            self.config.key_size,
308        )?;
309
310        let cipher = ChaCha20Poly1305::new_from_slice(&key)
311            .map_err(|_| anyhow!("invalid ChaCha20-Poly1305 key size"))?;
312        let nonce_ref = ChaChaNonce::from_slice(&nonce);
313
314        let ciphertext = cipher
315            .encrypt(nonce_ref, plaintext)
316            .map_err(|e| anyhow!("ChaCha20-Poly1305 encryption failed: {e}"))?;
317
318        Ok(EncryptedData {
319            version: 1,
320            algorithm: self.algorithm().to_string(),
321            kdf: self.config.kdf.clone(),
322            iterations: self.config.kdf_iterations,
323            salt,
324            nonce,
325            ciphertext,
326            created_at: now_unix(),
327            metadata: Default::default(),
328        })
329    }
330
331    fn decrypt(&self, encrypted: &EncryptedData, password: &str) -> Result<Vec<u8>> {
332        if encrypted.algorithm != self.algorithm() {
333            return Err(anyhow!(
334                "algorithm mismatch: provider={} payload={}",
335                self.algorithm(),
336                encrypted.algorithm
337            ));
338        }
339        if encrypted.nonce.len() != self.config.nonce_size {
340            return Err(anyhow!(
341                "invalid nonce size: expected {}, got {}",
342                self.config.nonce_size,
343                encrypted.nonce.len()
344            ));
345        }
346
347        let kdf = if encrypted.kdf.is_empty() {
348            &self.config.kdf
349        } else {
350            &encrypted.kdf
351        };
352        let iterations = if encrypted.iterations == 0 {
353            self.config.kdf_iterations
354        } else {
355            encrypted.iterations
356        };
357
358        let key = derive_key(
359            password,
360            &encrypted.salt,
361            kdf,
362            iterations,
363            self.config.key_size,
364        )?;
365
366        let cipher = ChaCha20Poly1305::new_from_slice(&key)
367            .map_err(|_| anyhow!("invalid ChaCha20-Poly1305 key size"))?;
368        let nonce_ref = ChaChaNonce::from_slice(&encrypted.nonce);
369
370        cipher
371            .decrypt(nonce_ref, encrypted.ciphertext.as_ref())
372            .map_err(|_| {
373                anyhow!("ChaCha20-Poly1305 authentication failed (wrong password or tampered data)")
374            })
375    }
376
377    fn algorithm(&self) -> &str {
378        "chacha20-poly1305"
379    }
380}
381
382fn validate_aead_config(config: &EncryptionConfig, expected_algorithm: &str) -> Result<()> {
383    if !config.enabled {
384        return Err(anyhow!("encryption provider requires enabled config"));
385    }
386    if config.key_size != 32 {
387        return Err(anyhow!("{expected_algorithm} requires 32-byte key"));
388    }
389    if config.nonce_size != 12 {
390        return Err(anyhow!("{expected_algorithm} requires 12-byte nonce"));
391    }
392    if config.salt_size < 16 {
393        return Err(anyhow!("salt_size must be at least 16 bytes"));
394    }
395    if config.kdf_iterations < 10_000 {
396        return Err(anyhow!(
397            "kdf_iterations must be at least 10,000 for production safety"
398        ));
399    }
400    Ok(())
401}
402
403fn derive_key(
404    password: &str,
405    salt: &[u8],
406    kdf: &str,
407    iterations: u32,
408    key_size: usize,
409) -> Result<Vec<u8>> {
410    if password.is_empty() {
411        return Err(anyhow!("password cannot be empty"));
412    }
413    if salt.is_empty() {
414        return Err(anyhow!("salt cannot be empty"));
415    }
416    if iterations == 0 {
417        return Err(anyhow!("kdf iterations must be > 0"));
418    }
419
420    match kdf.to_ascii_lowercase().as_str() {
421        "pbkdf2" => utils::derive_key_pbkdf2(password, salt, iterations, key_size),
422        other => Err(anyhow!("unsupported kdf '{}'; only pbkdf2 is supported", other)),
423    }
424}
425
426pub mod utils {
427    use super::*;
428
429    pub fn random_bytes(len: usize) -> Result<Vec<u8>> {
430        if len == 0 {
431            return Ok(Vec::new());
432        }
433        let mut output = vec![0u8; len];
434        getrandom::getrandom(&mut output)
435            .map_err(|e| anyhow!("secure random generation failed: {e}"))?;
436        Ok(output)
437    }
438
439    pub fn derive_key_pbkdf2(
440        password: &str,
441        salt: &[u8],
442        iterations: u32,
443        key_size: usize,
444    ) -> Result<Vec<u8>> {
445        if key_size == 0 {
446            return Err(anyhow!("key_size must be > 0"));
447        }
448        if iterations == 0 {
449            return Err(anyhow!("iterations must be > 0"));
450        }
451
452        let mut key = vec![0u8; key_size];
453        pbkdf2_hmac::<Sha256>(password.as_bytes(), salt, iterations, &mut key);
454        Ok(key)
455    }
456}
457
/// Current Unix time in whole seconds; `0` if the system clock reads earlier than the epoch.
fn now_unix() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
464
#[cfg(test)]
mod tests {
    use super::*;

    fn enabled_aes_config() -> EncryptionConfig {
        EncryptionConfig {
            enabled: true,
            algorithm: "aes-256-gcm".to_string(),
            ..EncryptionConfig::default()
        }
    }

    fn enabled_chacha_config() -> EncryptionConfig {
        EncryptionConfig {
            enabled: true,
            algorithm: "chacha20-poly1305".to_string(),
            ..EncryptionConfig::default()
        }
    }

    /// `enabled_aes_config` at the minimum accepted iteration count, so plumbing tests do
    /// not pay for the default number of PBKDF2 rounds on every call.
    fn fast_aes_config() -> EncryptionConfig {
        EncryptionConfig {
            kdf_iterations: 10_000,
            ..enabled_aes_config()
        }
    }

    /// `enabled_chacha_config` at the minimum accepted iteration count.
    fn fast_chacha_config() -> EncryptionConfig {
        EncryptionConfig {
            kdf_iterations: 10_000,
            ..enabled_chacha_config()
        }
    }

    #[test]
    fn encryption_config_default() {
        let config = EncryptionConfig::default();
        assert!(!config.enabled);
        assert_eq!(config.algorithm, "aes-256-gcm");
        assert_eq!(config.key_size, 32);
        assert_eq!(config.nonce_size, 12);
    }

    #[test]
    fn noop_encryption_roundtrip() {
        let provider = Box::new(NoOpEncryption);
        let plaintext = b"test data";
        let encrypted = provider.encrypt(plaintext, "password").unwrap();
        let decrypted = provider.decrypt(&encrypted, "password").unwrap();
        assert_eq!(plaintext, &decrypted[..]);
        assert_eq!(encrypted.algorithm, "none");
    }

    #[test]
    fn encryption_manager_disabled() {
        let manager = EncryptionManager::disabled();
        assert!(!manager.is_enabled());

        let plaintext = b"test data";
        let encrypted = manager.encrypt(plaintext, "password").unwrap();
        let decrypted = manager.decrypt(&encrypted, "password").unwrap();
        assert_eq!(plaintext, &decrypted[..]);
    }

    #[test]
    fn encryption_manager_from_config_selects_aes() {
        let manager = EncryptionManager::from_config(enabled_aes_config()).unwrap();
        assert!(manager.is_enabled());
        assert_eq!(manager.config().algorithm, "aes-256-gcm");
    }

    #[test]
    fn encryption_manager_from_config_selects_chacha() {
        let manager = EncryptionManager::from_config(enabled_chacha_config()).unwrap();
        assert!(manager.is_enabled());
        assert_eq!(manager.config().algorithm, "chacha20-poly1305");
    }

    #[test]
    fn encrypted_data_serialization() {
        let encrypted = EncryptedData {
            version: 1,
            algorithm: "aes-256-gcm".to_string(),
            kdf: "pbkdf2".to_string(),
            iterations: 1000,
            salt: vec![1, 2, 3, 4],
            nonce: vec![5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
            ciphertext: vec![9, 10, 11, 12],
            created_at: 1,
            metadata: Default::default(),
        };

        let json = serde_json::to_string(&encrypted).unwrap();
        let deserialized: EncryptedData = serde_json::from_str(&json).unwrap();

        assert_eq!(encrypted.version, deserialized.version);
        assert_eq!(encrypted.algorithm, deserialized.algorithm);
        assert_eq!(encrypted.salt, deserialized.salt);
    }

    #[test]
    fn aes_encryption_roundtrip() {
        let provider = Aes256GcmEncryption::new(enabled_aes_config()).unwrap();
        let plaintext = b"secret data";
        let encrypted = provider.encrypt(plaintext, "strong-password").unwrap();
        let decrypted = provider.decrypt(&encrypted, "strong-password").unwrap();
        assert_eq!(plaintext, &decrypted[..]);
        assert_ne!(encrypted.ciphertext, plaintext);
    }

    #[test]
    fn aes_wrong_password_fails() {
        let provider = Aes256GcmEncryption::new(enabled_aes_config()).unwrap();
        let encrypted = provider.encrypt(b"secret", "correct-password").unwrap();
        assert!(provider.decrypt(&encrypted, "wrong-password").is_err());
    }

    #[test]
    fn chacha_encryption_roundtrip() {
        let provider = ChaCha20Poly1305Encryption::new(enabled_chacha_config()).unwrap();
        let plaintext = b"secret data";
        let encrypted = provider.encrypt(plaintext, "strong-password").unwrap();
        let decrypted = provider.decrypt(&encrypted, "strong-password").unwrap();
        assert_eq!(plaintext, &decrypted[..]);
        assert_ne!(encrypted.ciphertext, plaintext);
    }

    #[test]
    fn chacha_wrong_password_fails() {
        let provider = ChaCha20Poly1305Encryption::new(enabled_chacha_config()).unwrap();
        let encrypted = provider.encrypt(b"secret", "correct-password").unwrap();
        assert!(provider.decrypt(&encrypted, "wrong-password").is_err());
    }

    #[test]
    fn pbkdf2_key_derivation() {
        let password = "super_secret";
        let salt = b"random_salt";
        let key1 = utils::derive_key_pbkdf2(password, salt, 1000, 32).unwrap();
        let key2 = utils::derive_key_pbkdf2(password, salt, 1000, 32).unwrap();

        assert_eq!(key1, key2);
        assert_eq!(key1.len(), 32);
    }

    #[test]
    fn random_bytes_generation() {
        let bytes1 = utils::random_bytes(16).unwrap();
        let bytes2 = utils::random_bytes(16).unwrap();

        assert_eq!(bytes1.len(), 16);
        assert_eq!(bytes2.len(), 16);
        assert_ne!(bytes1, bytes2);
    }

    #[test]
    fn invalid_kdf_rejected() {
        let mut config = enabled_aes_config();
        config.kdf = "unsupported".into();
        let provider = Aes256GcmEncryption::new(config).unwrap();
        assert!(provider.encrypt(b"x", "pw").is_err());
    }

    #[test]
    fn weak_config_is_rejected() {
        let mut config = enabled_aes_config();
        config.kdf_iterations = 100;
        assert!(Aes256GcmEncryption::new(config).is_err());
    }

    // The AEAD tag must catch a single flipped ciphertext bit.
    #[test]
    fn aes_tampered_ciphertext_fails() {
        let provider = Aes256GcmEncryption::new(fast_aes_config()).unwrap();
        let mut encrypted = provider.encrypt(b"secret", "pw").unwrap();
        encrypted.ciphertext[0] ^= 0x01;
        assert!(provider.decrypt(&encrypted, "pw").is_err());
    }

    #[test]
    fn chacha_tampered_ciphertext_fails() {
        let provider = ChaCha20Poly1305Encryption::new(fast_chacha_config()).unwrap();
        let mut encrypted = provider.encrypt(b"secret", "pw").unwrap();
        encrypted.ciphertext[0] ^= 0x01;
        assert!(provider.decrypt(&encrypted, "pw").is_err());
    }

    #[test]
    fn providers_reject_foreign_algorithm_label() {
        let aes = Aes256GcmEncryption::new(fast_aes_config()).unwrap();
        let chacha = ChaCha20Poly1305Encryption::new(fast_chacha_config()).unwrap();
        let encrypted = aes.encrypt(b"secret", "pw").unwrap();
        assert!(chacha.decrypt(&encrypted, "pw").is_err());
        assert!(NoOpEncryption.decrypt(&encrypted, "pw").is_err());
    }

    // A short nonce must be rejected with an error rather than panicking in `from_slice`.
    #[test]
    fn decrypt_rejects_wrong_nonce_length() {
        let provider = Aes256GcmEncryption::new(fast_aes_config()).unwrap();
        let mut encrypted = provider.encrypt(b"secret", "pw").unwrap();
        encrypted.nonce.truncate(8);
        assert!(provider.decrypt(&encrypted, "pw").is_err());
    }

    #[test]
    fn decrypt_uses_iterations_recorded_in_payload() {
        let writer = Aes256GcmEncryption::new(fast_aes_config()).unwrap();
        let reader = Aes256GcmEncryption::new(EncryptionConfig {
            kdf_iterations: 20_000,
            ..enabled_aes_config()
        })
        .unwrap();
        let encrypted = writer.encrypt(b"secret", "pw").unwrap();
        assert_eq!(encrypted.iterations, 10_000);
        assert_eq!(reader.decrypt(&encrypted, "pw").unwrap(), b"secret");
    }

    #[test]
    fn salt_and_nonce_are_fresh_per_encryption() {
        let provider = Aes256GcmEncryption::new(fast_aes_config()).unwrap();
        let first = provider.encrypt(b"same", "pw").unwrap();
        let second = provider.encrypt(b"same", "pw").unwrap();
        assert_ne!(first.salt, second.salt);
        assert_ne!(first.nonce, second.nonce);
        assert_ne!(first.ciphertext, second.ciphertext);
    }

    #[test]
    fn empty_password_rejected() {
        let provider = Aes256GcmEncryption::new(fast_aes_config()).unwrap();
        assert!(provider.encrypt(b"x", "").is_err());
    }

    #[test]
    fn manager_enabled_roundtrip() {
        let manager = EncryptionManager::from_config(fast_aes_config()).unwrap();
        let blob = manager.encrypt(b"secret data", "pw").unwrap();
        assert_eq!(manager.decrypt(&blob, "pw").unwrap(), b"secret data");
        assert!(manager.decrypt(&blob, "other").is_err());
    }

    #[test]
    fn disabled_manager_refuses_encrypted_blob() {
        let enabled = EncryptionManager::from_config(fast_aes_config()).unwrap();
        let blob = enabled.encrypt(b"secret", "pw").unwrap();
        assert!(EncryptionManager::disabled().decrypt(&blob, "pw").is_err());
    }

    #[test]
    fn manager_rejects_malformed_blob() {
        let manager = EncryptionManager::disabled();
        assert!(manager.decrypt(b"not json", "pw").is_err());
    }

    #[test]
    fn from_config_rejects_unknown_algorithm() {
        let config = EncryptionConfig {
            algorithm: "rot13".into(),
            ..enabled_aes_config()
        };
        assert!(EncryptionManager::from_config(config).is_err());
    }

    #[test]
    fn from_config_algorithm_is_case_insensitive() {
        let config = EncryptionConfig {
            algorithm: "AES-256-GCM".into(),
            ..fast_aes_config()
        };
        let manager = EncryptionManager::from_config(config).unwrap();
        let blob = manager.encrypt(b"x", "pw").unwrap();
        assert_eq!(manager.decrypt(&blob, "pw").unwrap(), b"x");
    }

    #[test]
    fn legacy_payload_fields_take_defaults() {
        let json = r#"{"version":1,"algorithm":"aes-256-gcm","salt":"00","nonce":"00","ciphertext":"00"}"#;
        let parsed: EncryptedData = serde_json::from_str(json).unwrap();
        assert_eq!(parsed.kdf, "pbkdf2");
        assert_eq!(parsed.iterations, default_kdf_iterations());
        assert!(parsed.metadata.is_empty());
    }

    #[test]
    fn invalid_hex_is_rejected() {
        let json = r#"{"version":1,"algorithm":"aes-256-gcm","salt":"zz","nonce":"00","ciphertext":"00"}"#;
        assert!(serde_json::from_str::<EncryptedData>(json).is_err());
    }
}