1use aes_gcm::aead::{Aead, KeyInit};
9use aes_gcm::{Aes256Gcm, Key, Nonce};
10use base64::{Engine, engine::general_purpose::STANDARD as BASE64};
11use hmac::{Hmac, Mac};
12use sha2::Sha256;
13
/// HMAC instantiated with SHA-256; used by `derive_key` to turn a password into a 32-byte key.
type HmacSha256 = Hmac<Sha256>;
15
/// Encrypts and decrypts configuration values with AES-256-GCM.
///
/// Encrypted values are rendered as `ENC(<base64>)` strings so they can sit inline in
/// configuration files next to plain values.
///
/// NOTE(review): no `Debug`/`Clone` is derived; presumably deliberate so the key cannot be
/// printed by accident — confirm before adding derives.
pub struct ConfigEncryptor {
    /// Raw AES-256 key, either derived from a password (`new`) or supplied directly (`with_key`).
    key: [u8; 32],
}
21
22impl ConfigEncryptor {
23 pub fn new(password: &str) -> Self {
28 let key = derive_key(password);
29 Self { key }
30 }
31
32 pub fn with_key(key: [u8; 32]) -> Self {
35 Self { key }
36 }
37
38 pub fn encrypt(&self, plaintext: &str) -> Result<String, EncryptError> {
41 let key = Key::<Aes256Gcm>::from_slice(&self.key);
42 let cipher = Aes256Gcm::new(key);
43 let nonce_bytes: [u8; 12] = rand::random();
44 let nonce = Nonce::from(nonce_bytes);
45 let ciphertext = cipher
46 .encrypt(&nonce, plaintext.as_bytes())
47 .map_err(|_| EncryptError::EncryptionFailed)?;
48
49 let mut combined = Vec::with_capacity(12 + ciphertext.len());
50 combined.extend_from_slice(&nonce_bytes);
51 combined.extend_from_slice(&ciphertext);
52
53 Ok(format!("ENC({})", BASE64.encode(&combined)))
54 }
55
56 pub fn decrypt(&self, encrypted: &str) -> Result<String, EncryptError> {
59 let payload = if let Some(inner) = extract_enc_value(encrypted) {
60 inner
61 } else {
62 encrypted
63 };
64
65 let combined = BASE64
66 .decode(payload)
67 .map_err(|e| EncryptError::Base64Error(e.to_string()))?;
68
69 if combined.len() < 13 {
70 return Err(EncryptError::InvalidPayload);
71 }
72
73 let (nonce_bytes, ciphertext) = combined.split_at(12);
74 let nonce = Nonce::from_slice(nonce_bytes);
75
76 let key = Key::<Aes256Gcm>::from_slice(&self.key);
77 let cipher = Aes256Gcm::new(key);
78 let plaintext = cipher
79 .decrypt(nonce, ciphertext)
80 .map_err(|_| EncryptError::DecryptionFailed)?;
81
82 String::from_utf8(plaintext).map_err(|e| EncryptError::Utf8Error(e.to_string()))
83 }
84
85 pub fn is_encrypted(value: &str) -> bool {
88 extract_enc_value(value).is_some()
89 }
90
91 pub fn maybe_decrypt(&self, value: &str) -> Result<String, EncryptError> {
94 if Self::is_encrypted(value) {
95 self.decrypt(value)
96 } else {
97 Ok(value.to_string())
98 }
99 }
100
101 pub fn decrypt_json_value(&self, value: &mut serde_json::Value) -> Result<(), EncryptError> {
104 match value {
105 serde_json::Value::String(s) if Self::is_encrypted(s) => {
106 *s = self.decrypt(s)?;
107 },
108 serde_json::Value::Object(map) => {
109 for v in map.values_mut() {
110 self.decrypt_json_value(v)?;
111 }
112 },
113 serde_json::Value::Array(arr) => {
114 for v in arr.iter_mut() {
115 self.decrypt_json_value(v)?;
116 }
117 },
118 _ => {},
119 }
120 Ok(())
121 }
122}
123
124fn derive_key(password: &str) -> [u8; 32] {
127 let mut mac =
128 <HmacSha256 as Mac>::new_from_slice(b"hiver-config-encryptor").expect("HMAC key is valid");
129 mac.update(password.as_bytes());
130 let result = mac.finalize().into_bytes();
131
132 let mut key = [0u8; 32];
133 key.copy_from_slice(&result);
134 key
135}
136
/// Returns the text between the leading `ENC(` and the trailing `)` of `value`, after trimming
/// surrounding whitespace, or `None` if `value` is not in that wrapped form.
///
/// The inner text is returned verbatim (it may be empty, e.g. for `"ENC()"`) and is not
/// validated as base64 here. Matching is case-sensitive.
fn extract_enc_value(value: &str) -> Option<&str> {
    value.trim().strip_prefix("ENC(")?.strip_suffix(')')
}
147
/// Errors returned by [`ConfigEncryptor`] operations.
#[derive(Debug, thiserror::Error)]
pub enum EncryptError {
    /// Invalid key material.
    ///
    /// NOTE(review): not constructed anywhere in this file — confirm whether it is used
    /// elsewhere or kept only for API compatibility.
    #[error("Invalid encryption key")]
    InvalidKey,
    /// The AES-GCM encryption call returned an error.
    #[error("Encryption failed")]
    EncryptionFailed,
    /// The AES-GCM decryption call returned an error, e.g. the key does not match or the
    /// payload was altered.
    #[error("Decryption failed (wrong password or corrupted data)")]
    DecryptionFailed,
    /// The decoded payload is too short to be a valid encrypted value.
    #[error("Invalid payload (too short)")]
    InvalidPayload,
    /// The payload was not valid base64; carries the decoder's error message.
    #[error("Base64 error: {0}")]
    Base64Error(String),
    /// The decrypted bytes were not valid UTF-8; carries the conversion error message.
    #[error("UTF-8 error: {0}")]
    Utf8Error(String),
}
171
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let enc = ConfigEncryptor::new("my-secret-password");
        let original = "database-password-123";
        let encrypted = enc.encrypt(original).unwrap();

        assert!(encrypted.starts_with("ENC("));
        assert!(encrypted.ends_with(')'));
        assert_ne!(encrypted, original);

        let decrypted = enc.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, original);
    }

    #[test]
    fn test_empty_plaintext_roundtrip() {
        // Shortest valid payload: nonce + authentication tag, no ciphertext bytes.
        let enc = ConfigEncryptor::new("pass");
        let encrypted = enc.encrypt("").unwrap();
        assert_eq!(enc.decrypt(&encrypted).unwrap(), "");
    }

    #[test]
    fn test_same_password_interoperates() {
        let writer = ConfigEncryptor::new("shared-password");
        let reader = ConfigEncryptor::new("shared-password");
        let encrypted = writer.encrypt("secret").unwrap();
        assert_eq!(reader.decrypt(&encrypted).unwrap(), "secret");
    }

    #[test]
    fn test_encrypt_produces_different_ciphertexts() {
        let enc = ConfigEncryptor::new("password");
        let encrypted1 = enc.encrypt("same-value").unwrap();
        let encrypted2 = enc.encrypt("same-value").unwrap();
        assert_ne!(encrypted1, encrypted2);
    }

    #[test]
    fn test_wrong_password_fails() {
        let enc1 = ConfigEncryptor::new("correct-password");
        let enc2 = ConfigEncryptor::new("wrong-password");
        let encrypted = enc1.encrypt("secret").unwrap();
        assert!(matches!(
            enc2.decrypt(&encrypted),
            Err(EncryptError::DecryptionFailed)
        ));
    }

    #[test]
    fn test_tampered_ciphertext_fails() {
        let enc = ConfigEncryptor::new("pass");
        let encrypted = enc.encrypt("secret").unwrap();

        // Swap one base64 character well inside the payload (never the padding), which
        // flips decoded bits while keeping the string valid base64.
        let mut chars: Vec<char> = encrypted.chars().collect();
        let idx = 10;
        chars[idx] = if chars[idx] == 'A' { 'B' } else { 'A' };
        let tampered: String = chars.into_iter().collect();

        assert!(matches!(
            enc.decrypt(&tampered),
            Err(EncryptError::DecryptionFailed)
        ));
    }

    #[test]
    fn test_decrypt_accepts_bare_and_padded_forms() {
        let enc = ConfigEncryptor::new("pass");
        let encrypted = enc.encrypt("secret").unwrap();

        let bare = encrypted
            .strip_prefix("ENC(")
            .and_then(|s| s.strip_suffix(')'))
            .unwrap();
        assert_eq!(enc.decrypt(bare).unwrap(), "secret");

        let padded = format!("  {encrypted}\n");
        assert_eq!(enc.decrypt(&padded).unwrap(), "secret");
    }

    #[test]
    fn test_invalid_base64_is_reported() {
        let enc = ConfigEncryptor::new("pass");
        assert!(matches!(
            enc.decrypt("ENC(!!not base64!!)"),
            Err(EncryptError::Base64Error(_))
        ));
    }

    #[test]
    fn test_empty_payload_is_invalid() {
        let enc = ConfigEncryptor::new("pass");
        assert!(matches!(
            enc.decrypt("ENC()"),
            Err(EncryptError::InvalidPayload)
        ));
    }

    #[test]
    fn test_is_encrypted() {
        assert!(ConfigEncryptor::is_encrypted("ENC(abc123)"));
        assert!(ConfigEncryptor::is_encrypted(" ENC(abc123) "));
        assert!(ConfigEncryptor::is_encrypted("ENC()"));
        assert!(!ConfigEncryptor::is_encrypted("plain-text"));
        assert!(!ConfigEncryptor::is_encrypted("ENC("));
        assert!(!ConfigEncryptor::is_encrypted("enc(abc123)"));
    }

    #[test]
    fn test_maybe_decrypt() {
        let enc = ConfigEncryptor::new("pass");
        let encrypted = enc.encrypt("secret").unwrap();

        assert_eq!(enc.maybe_decrypt(&encrypted).unwrap(), "secret");
        assert_eq!(enc.maybe_decrypt("plain").unwrap(), "plain");
    }

    #[test]
    fn test_decrypt_json_value() {
        let enc = ConfigEncryptor::new("pass");
        let enc_db = enc.encrypt("db-password").unwrap();
        let enc_api = enc.encrypt("api-key").unwrap();
        let enc_token = enc.encrypt("token-1").unwrap();

        let mut json = serde_json::json!({
            "database": {
                "url": "postgres://localhost:5432/mydb",
                "password": enc_db,
            },
            "api_key": enc_api,
            "timeout": 30,
            "names": ["alice", "bob"],
            "tokens": [enc_token, "plain-token"],
        });

        enc.decrypt_json_value(&mut json).unwrap();

        assert_eq!(json["database"]["password"], "db-password");
        assert_eq!(json["api_key"], "api-key");
        assert_eq!(json["database"]["url"], "postgres://localhost:5432/mydb");
        assert_eq!(json["timeout"], 30);
        assert_eq!(json["names"], serde_json::json!(["alice", "bob"]));
        assert_eq!(json["tokens"], serde_json::json!(["token-1", "plain-token"]));
    }

    #[test]
    fn test_with_raw_key() {
        let key = [42u8; 32];
        let enc = ConfigEncryptor::with_key(key);
        let encrypted = enc.encrypt("test").unwrap();
        let decrypted = enc.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, "test");
    }
}