1use crate::types::EncryptionInfo;
9use aes_gcm::{
10 aead::{Aead, KeyInit},
11 Aes256Gcm, Key, Nonce,
12};
13use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
14use hkdf::Hkdf;
15use getrandom::getrandom;
16use sha2::Sha256;
17use std::{fs, io::Write, path::Path};
18use thiserror::Error;
19use x25519_dalek::{x25519, PublicKey};
20use zeroize::Zeroizing;
21
22#[derive(Debug, Error)]
24pub enum EncryptionError {
25 #[error("Failed to generate random bytes: {0}")]
26 RandomGeneration(String),
27
28 #[error("Failed to encrypt data: {0}")]
29 EncryptionFailed(String),
30
31 #[error("Failed to decrypt data: {0}")]
32 DecryptionFailed(String),
33
34 #[error("Invalid key format: {0}")]
35 InvalidKey(String),
36
37 #[error("IO error: {0}")]
38 Io(#[from] std::io::Error),
39
40 #[error("Base64 decode error: {0}")]
41 Base64Decode(#[from] base64::DecodeError),
42
43 #[error("Invalid nonce size: expected 12 bytes, got {0}")]
44 InvalidNonceSize(usize),
45
46 #[error("Invalid auth tag size: expected 16 bytes, got {0}")]
47 InvalidAuthTagSize(usize),
48}
49
50pub struct EncryptionManager {
52 recipient_public_key: PublicKey,
54}
55
56#[derive(Debug)]
58pub struct EncryptionResult {
59 pub ciphertext: Vec<u8>,
61 pub info: EncryptionInfo,
63}
64
65pub struct KeyPair {
68 pub secret: Zeroizing<[u8; 32]>,
70 pub public: PublicKey,
72}
73
74impl EncryptionManager {
75 pub fn new(recipient_public_key: PublicKey) -> Self {
77 Self {
78 recipient_public_key,
79 }
80 }
81
82 pub fn from_base64_public_key(public_key_b64: &str) -> Result<Self, EncryptionError> {
84 let key_bytes = BASE64.decode(public_key_b64)?;
85 if key_bytes.len() != 32 {
86 return Err(EncryptionError::InvalidKey(format!(
87 "Expected 32 bytes, got {}",
88 key_bytes.len()
89 )));
90 }
91
92 let mut key_array = [0u8; 32];
93 key_array.copy_from_slice(&key_bytes);
94 let public_key = PublicKey::from(key_array);
95
96 Ok(Self::new(public_key))
97 }
98
99 pub fn from_public_key_file<P: AsRef<Path>>(path: P) -> Result<Self, EncryptionError> {
101 let key_bytes = fs::read(path)?;
102 if key_bytes.len() != 32 {
103 return Err(EncryptionError::InvalidKey(format!(
104 "Expected 32 bytes in key file, got {}",
105 key_bytes.len()
106 )));
107 }
108
109 let mut key_array = [0u8; 32];
110 key_array.copy_from_slice(&key_bytes);
111 let public_key = PublicKey::from(key_array);
112
113 Ok(Self::new(public_key))
114 }
115
116 pub fn encrypt(&self, plaintext: &[u8]) -> Result<EncryptionResult, EncryptionError> {
118 let mut eph_secret = [0u8; 32];
120 getrandom(&mut eph_secret)
121 .map_err(|e| EncryptionError::RandomGeneration(e.to_string()))?;
122
123 let mut base_point = [0u8; 32];
125 base_point[0] = 9;
126 let eph_public_bytes = x25519(eph_secret, base_point);
127 let ephemeral_public = PublicKey::from(eph_public_bytes);
128
129 let shared_secret_bytes = x25519(eph_secret, self.recipient_public_key.to_bytes());
131
132 let mut z = Zeroizing::new(eph_secret);
134 for b in z.iter_mut() { *b = 0; }
135
136 let hkdf = Hkdf::<Sha256>::new(None, &shared_secret_bytes);
138 let mut symmetric_key = [0u8; 32]; hkdf.expand(b"JMIX-AES256-GCM", &mut symmetric_key)
140 .map_err(|e| {
141 EncryptionError::EncryptionFailed(format!("HKDF expansion failed: {}", e))
142 })?;
143
144 let mut iv = [0u8; 12];
146 getrandom(&mut iv).map_err(|e| EncryptionError::RandomGeneration(e.to_string()))?;
147
148 let cipher = Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(&symmetric_key));
150 let nonce = Nonce::from_slice(&iv);
151
152 let ciphertext = cipher.encrypt(nonce, plaintext).map_err(|e| {
153 EncryptionError::EncryptionFailed(format!("AES-GCM encryption failed: {}", e))
154 })?;
155
156 if ciphertext.len() < 16 {
158 return Err(EncryptionError::EncryptionFailed(
159 "Ciphertext too short".to_string(),
160 ));
161 }
162
163 let (data, auth_tag) = ciphertext.split_at(ciphertext.len() - 16);
164
165 let info = EncryptionInfo {
167 algorithm: "AES-256-GCM".to_string(),
168 ephemeral_public_key: BASE64.encode(ephemeral_public.as_bytes()),
169 iv: BASE64.encode(&iv),
170 auth_tag: BASE64.encode(auth_tag),
171 };
172
173 Ok(EncryptionResult {
174 ciphertext: data.to_vec(),
175 info,
176 })
177 }
178}
179
180pub struct DecryptionManager {
183 secret_key: Zeroizing<[u8; 32]>,
185}
186
187impl DecryptionManager {
188 pub fn new(secret_key: [u8; 32]) -> Self {
190 Self {
191 secret_key: Zeroizing::new(secret_key),
192 }
193 }
194
195 pub fn from_bytes(key_bytes: [u8; 32]) -> Self {
197 Self::new(key_bytes)
198 }
199
200 pub fn from_secret_key_file<P: AsRef<Path>>(path: P) -> Result<Self, EncryptionError> {
202 let key_bytes = fs::read(path)?;
203 if key_bytes.len() != 32 {
204 return Err(EncryptionError::InvalidKey(format!(
205 "Expected 32 bytes in key file, got {}",
206 key_bytes.len()
207 )));
208 }
209
210 let mut key_array = [0u8; 32];
211 key_array.copy_from_slice(&key_bytes);
212
213 Ok(Self::from_bytes(key_array))
214 }
215
216 pub fn decrypt(
218 &self,
219 ciphertext: &[u8],
220 info: &EncryptionInfo,
221 ) -> Result<Vec<u8>, EncryptionError> {
222 if info.algorithm != "AES-256-GCM" {
224 return Err(EncryptionError::DecryptionFailed(format!(
225 "Unsupported algorithm: {}",
226 info.algorithm
227 )));
228 }
229
230 let ephemeral_public_bytes = BASE64.decode(&info.ephemeral_public_key)?;
232 if ephemeral_public_bytes.len() != 32 {
233 return Err(EncryptionError::InvalidKey(format!(
234 "Invalid ephemeral public key length: {}",
235 ephemeral_public_bytes.len()
236 )));
237 }
238
239 let mut key_array = [0u8; 32];
240 key_array.copy_from_slice(&ephemeral_public_bytes);
241 let ephemeral_public = PublicKey::from(key_array);
242
243 let iv_bytes = BASE64.decode(&info.iv)?;
245 let auth_tag_bytes = BASE64.decode(&info.auth_tag)?;
246
247 if iv_bytes.len() != 12 {
248 return Err(EncryptionError::InvalidNonceSize(iv_bytes.len()));
249 }
250
251 if auth_tag_bytes.len() != 16 {
252 return Err(EncryptionError::InvalidAuthTagSize(auth_tag_bytes.len()));
253 }
254
255 let shared_secret_bytes = x25519(*self.secret_key, ephemeral_public.to_bytes());
258
259 let hkdf = Hkdf::<Sha256>::new(None, &shared_secret_bytes);
261 let mut symmetric_key = [0u8; 32];
262 hkdf.expand(b"JMIX-AES256-GCM", &mut symmetric_key)
263 .map_err(|e| {
264 EncryptionError::DecryptionFailed(format!("HKDF expansion failed: {}", e))
265 })?;
266
267 let mut full_ciphertext = ciphertext.to_vec();
269 full_ciphertext.extend_from_slice(&auth_tag_bytes);
270
271 let cipher = Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(&symmetric_key));
273 let nonce = Nonce::from_slice(&iv_bytes);
274
275 let plaintext = cipher
276 .decrypt(nonce, full_ciphertext.as_slice())
277 .map_err(|e| {
278 EncryptionError::DecryptionFailed(format!("AES-GCM decryption failed: {}", e))
279 })?;
280
281 Ok(plaintext)
282 }
283}
284
285impl KeyPair {
286 pub fn generate() -> Self {
288 let mut secret = [0u8; 32];
289 getrandom(&mut secret).expect("OS RNG unavailable");
290
291 let mut base_point = [0u8; 32];
296 base_point[0] = 9;
297 let public_bytes = x25519(secret, base_point);
298 let public = PublicKey::from(public_bytes);
299
300 Self {
301 secret: Zeroizing::new(secret),
302 public,
303 }
304 }
305
306 pub fn from_secret_bytes(secret_bytes: [u8; 32]) -> Self {
308 let mut base_point = [0u8; 32];
312 base_point[0] = 9;
313 let public_bytes = x25519(secret_bytes, base_point);
314 let public = PublicKey::from(public_bytes);
315 Self {
316 secret: Zeroizing::new(secret_bytes),
317 public,
318 }
319 }
320
321 pub fn secret_bytes(&self) -> [u8; 32] {
323 *self.secret
324 }
325
326 pub fn public_bytes(&self) -> [u8; 32] {
328 self.public.to_bytes()
329 }
330
331 pub fn public_key_base64(&self) -> String {
333 BASE64.encode(self.public.as_bytes())
334 }
335
336 pub fn save_to_files<P: AsRef<Path>>(
338 &self,
339 secret_path: P,
340 public_path: P,
341 ) -> Result<(), EncryptionError> {
342 let mut secret_file = fs::File::create(secret_path)?;
344 secret_file.write_all(&self.secret_bytes())?;
345
346 let mut public_file = fs::File::create(public_path)?;
348 public_file.write_all(&self.public_bytes())?;
349
350 Ok(())
351 }
352
353 pub fn load_from_secret_file<P: AsRef<Path>>(secret_path: P) -> Result<Self, EncryptionError> {
355 let secret_bytes = fs::read(secret_path)?;
356 if secret_bytes.len() != 32 {
357 return Err(EncryptionError::InvalidKey(format!(
358 "Expected 32 bytes in secret key file, got {}",
359 secret_bytes.len()
360 )));
361 }
362
363 let mut key_array = [0u8; 32];
364 key_array.copy_from_slice(&secret_bytes);
365
366 Ok(Self::from_secret_bytes(key_array))
367 }
368}
369
#[cfg(test)]
mod tests {
    use super::*;
    use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
    use tempfile::TempDir;

    /// Generates a recipient key pair and encrypts `plaintext` for it. Returns the
    /// recipient's decryptor together with the encryption result.
    fn encrypt_for_fresh_recipient(
        plaintext: &[u8],
    ) -> Result<(DecryptionManager, EncryptionResult), EncryptionError> {
        let recipient = KeyPair::generate();
        let result = EncryptionManager::new(recipient.public).encrypt(plaintext)?;
        Ok((DecryptionManager::new(*recipient.secret), result))
    }

    #[test]
    fn test_keypair_generation() {
        let keypair = KeyPair::generate();

        // The public key must be reproducible from the secret alone.
        let rederived = KeyPair::from_secret_bytes(keypair.secret_bytes());
        assert_eq!(rederived.public_bytes(), keypair.public_bytes());

        // The base64 form must decode back to the raw public key.
        let decoded = BASE64
            .decode(keypair.public_key_base64())
            .expect("public key base64 must decode");
        assert_eq!(decoded, keypair.public_bytes());

        // Independently generated key pairs must not collide.
        assert_ne!(KeyPair::generate().secret_bytes(), keypair.secret_bytes());
    }

    #[test]
    fn test_keypair_save_load() -> Result<(), Box<dyn std::error::Error>> {
        let temp_dir = TempDir::new()?;
        let secret_path = temp_dir.path().join("secret.key");
        let public_path = temp_dir.path().join("public.key");

        let original_keypair = KeyPair::generate();
        original_keypair.save_to_files(&secret_path, &public_path)?;

        let loaded_keypair = KeyPair::load_from_secret_file(&secret_path)?;
        assert_eq!(
            original_keypair.secret_bytes(),
            loaded_keypair.secret_bytes()
        );
        assert_eq!(
            original_keypair.public_bytes(),
            loaded_keypair.public_bytes()
        );

        // The public key file holds exactly the raw public-key bytes.
        assert_eq!(std::fs::read(&public_path)?, original_keypair.public_bytes());

        // Keys loaded from the saved files must interoperate.
        let encryptor = EncryptionManager::from_public_key_file(&public_path)?;
        let decryptor = DecryptionManager::from_secret_key_file(&secret_path)?;
        let result = encryptor.encrypt(b"file keys")?;
        assert_eq!(decryptor.decrypt(&result.ciphertext, &result.info)?, b"file keys");

        Ok(())
    }

    #[test]
    fn test_key_files_with_wrong_length_rejected() -> Result<(), Box<dyn std::error::Error>> {
        let temp_dir = TempDir::new()?;
        let path = temp_dir.path().join("short.key");
        std::fs::write(&path, [0u8; 31])?;

        assert!(matches!(
            KeyPair::load_from_secret_file(&path),
            Err(EncryptionError::InvalidKey(_))
        ));
        assert!(matches!(
            DecryptionManager::from_secret_key_file(&path),
            Err(EncryptionError::InvalidKey(_))
        ));
        assert!(matches!(
            EncryptionManager::from_public_key_file(&path),
            Err(EncryptionError::InvalidKey(_))
        ));

        Ok(())
    }

    #[test]
    fn test_encryption_roundtrip() -> Result<(), Box<dyn std::error::Error>> {
        let plaintext = b"Hello, JMIX encryption!";
        let (decryption_manager, result) = encrypt_for_fresh_recipient(plaintext)?;

        // GCM is a stream mode: without the detached tag, sizes match exactly.
        assert_eq!(result.ciphertext.len(), plaintext.len());
        assert_eq!(result.info.algorithm, "AES-256-GCM");

        let decrypted = decryption_manager.decrypt(&result.ciphertext, &result.info)?;
        assert_eq!(decrypted, plaintext);

        Ok(())
    }

    #[test]
    fn test_empty_plaintext_roundtrip() -> Result<(), Box<dyn std::error::Error>> {
        let (decryption_manager, result) = encrypt_for_fresh_recipient(b"")?;
        assert!(result.ciphertext.is_empty());

        let decrypted = decryption_manager.decrypt(&result.ciphertext, &result.info)?;
        assert!(decrypted.is_empty());

        Ok(())
    }

    #[test]
    fn test_encryption_manager_from_base64() -> Result<(), Box<dyn std::error::Error>> {
        let keypair = KeyPair::generate();
        let manager = EncryptionManager::from_base64_public_key(&keypair.public_key_base64())?;

        let result = manager.encrypt(b"Test message")?;
        assert!(!result.ciphertext.is_empty());

        let decrypted =
            DecryptionManager::new(*keypair.secret).decrypt(&result.ciphertext, &result.info)?;
        assert_eq!(decrypted, b"Test message");

        Ok(())
    }

    #[test]
    fn test_from_base64_rejects_malformed_keys() {
        let short_key = BASE64.encode([0u8; 16]);
        assert!(matches!(
            EncryptionManager::from_base64_public_key(&short_key),
            Err(EncryptionError::InvalidKey(_))
        ));
        assert!(matches!(
            EncryptionManager::from_base64_public_key("not base64!"),
            Err(EncryptionError::Base64Decode(_))
        ));
    }

    #[test]
    fn test_encryption_different_ephemeral_keys() -> Result<(), Box<dyn std::error::Error>> {
        let recipient_keypair = KeyPair::generate();
        let encryption_manager = EncryptionManager::new(recipient_keypair.public);

        let plaintext = b"Same message";
        let result1 = encryption_manager.encrypt(plaintext)?;
        let result2 = encryption_manager.encrypt(plaintext)?;

        assert_ne!(
            result1.info.ephemeral_public_key,
            result2.info.ephemeral_public_key
        );
        assert_ne!(result1.info.iv, result2.info.iv);
        assert_ne!(result1.ciphertext, result2.ciphertext);

        Ok(())
    }

    #[test]
    fn test_invalid_decryption() -> Result<(), Box<dyn std::error::Error>> {
        let recipient_keypair = KeyPair::generate();
        let wrong_keypair = KeyPair::generate();
        let encryption_manager = EncryptionManager::new(recipient_keypair.public);
        let wrong_decryption_manager = DecryptionManager::new(*wrong_keypair.secret);

        let result = encryption_manager.encrypt(b"Secret message")?;
        let decrypt_result = wrong_decryption_manager.decrypt(&result.ciphertext, &result.info);
        assert!(matches!(
            decrypt_result,
            Err(EncryptionError::DecryptionFailed(_))
        ));

        Ok(())
    }

    #[test]
    fn test_tampered_ciphertext_rejected() -> Result<(), Box<dyn std::error::Error>> {
        let (decryption_manager, mut result) = encrypt_for_fresh_recipient(b"Integrity matters")?;
        result.ciphertext[0] ^= 0x01;

        let outcome = decryption_manager.decrypt(&result.ciphertext, &result.info);
        assert!(matches!(outcome, Err(EncryptionError::DecryptionFailed(_))));

        Ok(())
    }

    #[test]
    fn test_unsupported_algorithm_rejected() -> Result<(), Box<dyn std::error::Error>> {
        let (decryption_manager, mut result) = encrypt_for_fresh_recipient(b"payload")?;
        result.info.algorithm = "AES-128-CBC".to_string();

        let outcome = decryption_manager.decrypt(&result.ciphertext, &result.info);
        assert!(matches!(outcome, Err(EncryptionError::DecryptionFailed(_))));

        Ok(())
    }

    #[test]
    fn test_malformed_iv_and_tag_rejected() -> Result<(), Box<dyn std::error::Error>> {
        let (decryption_manager, mut result) = encrypt_for_fresh_recipient(b"payload")?;

        let valid_iv = std::mem::replace(&mut result.info.iv, BASE64.encode([0u8; 8]));
        assert!(matches!(
            decryption_manager.decrypt(&result.ciphertext, &result.info),
            Err(EncryptionError::InvalidNonceSize(8))
        ));

        result.info.iv = valid_iv;
        result.info.auth_tag = BASE64.encode([0u8; 4]);
        assert!(matches!(
            decryption_manager.decrypt(&result.ciphertext, &result.info),
            Err(EncryptionError::InvalidAuthTagSize(4))
        ));

        Ok(())
    }
}