1use crate::error::{PachaError, Result};
31use serde::{Deserialize, Serialize};
32
/// Eight-byte marker written at the start of every encrypted Pacha payload.
const MAGIC: &[u8; 8] = b"PACHAENC";

/// Format version byte written right after [`MAGIC`].
const VERSION: u8 = 1;

/// Length in bytes of the per-file key-derivation salt.
const SALT_LEN: usize = 32;

/// Length in bytes of the per-file cipher nonce.
const NONCE_LEN: usize = 12;

/// Length in bytes of the authentication tag appended to the ciphertext.
const TAG_LEN: usize = 16;

/// Total header length: magic, one version byte, salt, then nonce.
const HEADER_SIZE: usize = MAGIC.len() + 1 + SALT_LEN + NONCE_LEN;
50
/// Parsed form of the fixed-size header that prefixes every encrypted payload.
///
/// Serialized layout (see `to_bytes` / `from_bytes`):
/// `MAGIC (8 bytes) | version (1 byte) | salt (SALT_LEN bytes) | nonce (NONCE_LEN bytes)`.
/// The header does not record the Argon2 cost parameters, only the salt and nonce.
#[derive(Debug, Clone)]
pub struct EncryptedHeader {
    /// Format version byte; headers created by `new` carry `VERSION`.
    pub version: u8,
    /// Per-file salt combined with the password during key derivation.
    pub salt: [u8; SALT_LEN],
    /// Per-file nonce handed to the cipher for this payload.
    pub nonce: [u8; NONCE_LEN],
}
61
62impl EncryptedHeader {
63 #[must_use]
65 pub fn new() -> Self {
66 #[cfg(feature = "encryption")]
67 {
68 use rand::rngs::OsRng;
69 use rand::RngCore;
70 let mut salt = [0u8; SALT_LEN];
71 let mut nonce = [0u8; NONCE_LEN];
72 OsRng.fill_bytes(&mut salt);
73 OsRng.fill_bytes(&mut nonce);
74 Self { version: VERSION, salt, nonce }
75 }
76 #[cfg(not(feature = "encryption"))]
77 {
78 let seed = std::time::SystemTime::now()
80 .duration_since(std::time::UNIX_EPOCH)
81 .map(|d| d.as_nanos())
82 .unwrap_or(0);
83
84 let mut salt = [0u8; SALT_LEN];
85 let mut nonce = [0u8; NONCE_LEN];
86
87 for (i, byte) in salt.iter_mut().enumerate() {
88 *byte = ((seed >> (i % 16)) ^ (i as u128 * 7)) as u8;
89 }
90 for (i, byte) in nonce.iter_mut().enumerate() {
91 *byte = ((seed >> ((i + 32) % 16)) ^ (i as u128 * 13)) as u8;
92 }
93
94 Self { version: VERSION, salt, nonce }
95 }
96 }
97
98 #[must_use]
100 pub fn to_bytes(&self) -> Vec<u8> {
101 let mut bytes = Vec::with_capacity(HEADER_SIZE);
102 bytes.extend_from_slice(MAGIC);
103 bytes.push(self.version);
104 bytes.extend_from_slice(&self.salt);
105 bytes.extend_from_slice(&self.nonce);
106 bytes
107 }
108
109 pub fn from_bytes(data: &[u8]) -> Result<Self> {
111 if data.len() < HEADER_SIZE {
112 return Err(PachaError::InvalidFormat("encrypted file too short".to_string()));
113 }
114
115 if &data[0..8] != MAGIC {
117 return Err(PachaError::InvalidFormat("not an encrypted pacha file".to_string()));
118 }
119
120 let version = data[8];
121 if version != VERSION {
122 return Err(PachaError::InvalidFormat(format!(
123 "unsupported encryption version: {}",
124 version
125 )));
126 }
127
128 let mut salt = [0u8; SALT_LEN];
129 salt.copy_from_slice(&data[9..9 + SALT_LEN]);
130
131 let mut nonce = [0u8; NONCE_LEN];
132 nonce.copy_from_slice(&data[9 + SALT_LEN..HEADER_SIZE]);
133
134 Ok(Self { version, salt, nonce })
135 }
136}
137
138impl Default for EncryptedHeader {
139 fn default() -> Self {
140 Self::new()
141 }
142}
143
/// Argon2id cost parameters used to derive the encryption key from a password.
///
/// These values are not written into the encrypted header, so decryption must be given
/// the same configuration that was used for encryption. When built without the
/// `encryption` feature, the fallback key derivation ignores this config entirely.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EncryptionConfig {
    /// Memory cost in KiB, passed as the first (`m_cost`) argument of `argon2::Params::new`.
    pub memory_cost_kib: u32,
    /// Number of Argon2 passes (`t_cost`).
    pub time_cost: u32,
    /// Degree of parallelism / lane count (`p_cost`).
    pub parallelism: u32,
}
154
155impl Default for EncryptionConfig {
156 fn default() -> Self {
157 Self {
158 memory_cost_kib: 65536, time_cost: 3,
160 parallelism: 4,
161 }
162 }
163}
164
/// Derives a 32-byte key from `password` and `salt` with Argon2id (version 0x13).
///
/// # Errors
///
/// Returns `PachaError::Validation` if `config` holds parameters Argon2 rejects or if
/// hashing itself fails.
#[cfg(feature = "encryption")]
fn derive_key(
    password: &str,
    salt: &[u8; SALT_LEN],
    config: &EncryptionConfig,
) -> Result<[u8; 32]> {
    use argon2::{Algorithm, Argon2, Params, Version};

    const KEY_LEN: usize = 32;

    let m_cost = config.memory_cost_kib;
    let t_cost = config.time_cost;
    let p_cost = config.parallelism;

    let params = match Params::new(m_cost, t_cost, p_cost, Some(KEY_LEN)) {
        Ok(p) => p,
        Err(e) => {
            return Err(PachaError::Validation(format!("Invalid Argon2 params: {e}")));
        }
    };

    let mut key = [0u8; KEY_LEN];
    Argon2::new(Algorithm::Argon2id, Version::V0x13, params)
        .hash_password_into(password.as_bytes(), salt, &mut key)
        .map_err(|e| PachaError::Validation(format!("Key derivation failed: {e}")))?;

    Ok(key)
}
187
188#[cfg(not(feature = "encryption"))]
190fn derive_key(
191 password: &str,
192 salt: &[u8; SALT_LEN],
193 _config: &EncryptionConfig,
194) -> Result<[u8; 32]> {
195 let mut key = [0u8; 32];
197 let mut state = [0u8; 64];
198
199 for (i, &b) in password.as_bytes().iter().enumerate() {
200 state[i % 64] ^= b;
201 }
202 for (i, &b) in salt.iter().enumerate() {
203 state[(i + 32) % 64] ^= b;
204 }
205
206 for iteration in 0..10000u32 {
207 let iter_bytes = iteration.to_le_bytes();
208 for (i, &b) in iter_bytes.iter().enumerate() {
209 state[i] ^= b;
210 }
211 for i in 0..64 {
212 state[i] = state[i].wrapping_add(state[(i + 1) % 64]).wrapping_mul(33);
213 }
214 }
215
216 key.copy_from_slice(&state[0..32]);
217 Ok(key)
218}
219
/// Seals `data` with ChaCha20-Poly1305, returning ciphertext followed by the 16-byte tag.
///
/// # Errors
///
/// Returns `PachaError::Validation` if the cipher cannot be constructed or encryption fails.
#[cfg(feature = "encryption")]
fn chacha_encrypt(data: &[u8], key: &[u8; 32], nonce: &[u8; NONCE_LEN]) -> Result<Vec<u8>> {
    use chacha20poly1305::{
        aead::{Aead, KeyInit},
        ChaCha20Poly1305, Nonce,
    };

    let aead = match ChaCha20Poly1305::new_from_slice(key) {
        Ok(cipher) => cipher,
        Err(e) => return Err(PachaError::Validation(format!("Invalid key: {e}"))),
    };

    aead.encrypt(Nonce::from_slice(nonce), data)
        .map_err(|e| PachaError::Validation(format!("Encryption failed: {e}")))
}
237
/// Opens a ChaCha20-Poly1305 ciphertext (with appended tag) and returns the plaintext.
///
/// # Errors
///
/// Returns `PachaError::Validation` if the cipher cannot be constructed, and
/// `PachaError::InvalidFormat` if authentication fails.
#[cfg(feature = "encryption")]
fn chacha_decrypt(ciphertext: &[u8], key: &[u8; 32], nonce: &[u8; NONCE_LEN]) -> Result<Vec<u8>> {
    use chacha20poly1305::{
        aead::{Aead, KeyInit},
        ChaCha20Poly1305, Nonce,
    };

    let aead = match ChaCha20Poly1305::new_from_slice(key) {
        Ok(cipher) => cipher,
        Err(e) => return Err(PachaError::Validation(format!("Invalid key: {e}"))),
    };

    // The underlying AEAD error is discarded: a wrong key and tampered bytes are
    // reported with the same message.
    match aead.decrypt(Nonce::from_slice(nonce), ciphertext) {
        Ok(plaintext) => Ok(plaintext),
        Err(_) => Err(PachaError::InvalidFormat(
            "decryption failed: invalid password or corrupted data".to_string(),
        )),
    }
}
257
258#[cfg(not(feature = "encryption"))]
260fn chacha_encrypt(data: &[u8], key: &[u8; 32], nonce: &[u8; NONCE_LEN]) -> Result<Vec<u8>> {
261 let mut output = data.to_vec();
262 let mut keystream = [0u8; 64];
263
264 for (block_idx, chunk) in output.chunks_mut(64).enumerate() {
265 for (i, ks) in keystream.iter_mut().enumerate() {
266 *ks = key[i % 32]
267 .wrapping_add(nonce[i % NONCE_LEN])
268 .wrapping_add(block_idx as u8)
269 .wrapping_mul(i as u8 + 1);
270 }
271 for (i, byte) in chunk.iter_mut().enumerate() {
272 *byte ^= keystream[i];
273 }
274 }
275
276 let tag = compute_fallback_tag(&output, key);
278 output.extend_from_slice(&tag);
279
280 Ok(output)
281}
282
283#[cfg(not(feature = "encryption"))]
284fn chacha_decrypt(ciphertext: &[u8], key: &[u8; 32], nonce: &[u8; NONCE_LEN]) -> Result<Vec<u8>> {
285 if ciphertext.len() < TAG_LEN {
286 return Err(PachaError::InvalidFormat("ciphertext too short".to_string()));
287 }
288
289 let data = &ciphertext[..ciphertext.len() - TAG_LEN];
290 let stored_tag = &ciphertext[ciphertext.len() - TAG_LEN..];
291
292 let computed_tag = compute_fallback_tag(data, key);
294 if computed_tag != stored_tag {
295 return Err(PachaError::InvalidFormat(
296 "decryption failed: invalid password or corrupted data".to_string(),
297 ));
298 }
299
300 let mut output = data.to_vec();
302 let mut keystream = [0u8; 64];
303
304 for (block_idx, chunk) in output.chunks_mut(64).enumerate() {
305 for (i, ks) in keystream.iter_mut().enumerate() {
306 *ks = key[i % 32]
307 .wrapping_add(nonce[i % NONCE_LEN])
308 .wrapping_add(block_idx as u8)
309 .wrapping_mul(i as u8 + 1);
310 }
311 for (i, byte) in chunk.iter_mut().enumerate() {
312 *byte ^= keystream[i];
313 }
314 }
315
316 Ok(output)
317}
318
319#[cfg(not(feature = "encryption"))]
320fn compute_fallback_tag(ciphertext: &[u8], key: &[u8; 32]) -> [u8; TAG_LEN] {
321 let mut tag = [0u8; TAG_LEN];
322 let mut state = [0u64; 4];
323
324 for (i, &b) in key.iter().enumerate() {
325 state[i % 4] ^= (b as u64) << ((i * 8) % 64);
326 }
327
328 for (i, &b) in ciphertext.iter().enumerate() {
329 state[i % 4] = state[i % 4].wrapping_add(b as u64).wrapping_mul(0x100000001b3);
330 }
331
332 for (i, byte) in tag.iter_mut().enumerate() {
333 *byte = (state[i % 4] >> ((i % 8) * 8)) as u8;
334 }
335
336 tag
337}
338
339pub fn encrypt_model(data: &[u8], password: &str) -> Result<Vec<u8>> {
344 encrypt_model_with_config(data, password, &EncryptionConfig::default())
345}
346
347pub fn encrypt_model_with_config(
349 data: &[u8],
350 password: &str,
351 config: &EncryptionConfig,
352) -> Result<Vec<u8>> {
353 if password.is_empty() {
354 return Err(PachaError::InvalidFormat("encryption password cannot be empty".to_string()));
355 }
356
357 let header = EncryptedHeader::new();
358 let key = derive_key(password, &header.salt, config)?;
359
360 let ciphertext = chacha_encrypt(data, &key, &header.nonce)?;
362
363 let mut output = header.to_bytes();
365 output.extend_from_slice(&ciphertext);
366
367 Ok(output)
368}
369
370pub fn decrypt_model(encrypted_data: &[u8], password: &str) -> Result<Vec<u8>> {
372 decrypt_model_with_config(encrypted_data, password, &EncryptionConfig::default())
373}
374
375pub fn decrypt_model_with_config(
377 encrypted_data: &[u8],
378 password: &str,
379 config: &EncryptionConfig,
380) -> Result<Vec<u8>> {
381 if encrypted_data.len() < HEADER_SIZE + TAG_LEN {
382 return Err(PachaError::InvalidFormat("encrypted data too short".to_string()));
383 }
384
385 let header = EncryptedHeader::from_bytes(encrypted_data)?;
387
388 let ciphertext = &encrypted_data[HEADER_SIZE..];
390
391 let key = derive_key(password, &header.salt, config)?;
393
394 chacha_decrypt(ciphertext, &key, &header.nonce)
396}
397
398#[must_use]
400pub fn is_encrypted(data: &[u8]) -> bool {
401 data.len() >= 8 && &data[0..8] == MAGIC
402}
403
404pub fn get_version(data: &[u8]) -> Result<u8> {
406 if data.len() < 9 {
407 return Err(PachaError::InvalidFormat("data too short for version check".to_string()));
408 }
409 if &data[0..8] != MAGIC {
410 return Err(PachaError::InvalidFormat("not an encrypted pacha file".to_string()));
411 }
412 Ok(data[8])
413}
414
#[cfg(test)]
mod tests {
    use super::*;

    /// A short payload survives an encrypt/decrypt round trip unchanged.
    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let original = b"Hello, this is test model data!";
        let password = "my-secret-password";

        let encrypted = encrypt_model(original, password).unwrap();
        let decrypted = decrypt_model(&encrypted, password).unwrap();

        assert_eq!(original.as_slice(), decrypted.as_slice());
    }

    #[test]
    fn test_encrypt_decrypt_large_data() {
        let original: Vec<u8> = (0..10000).map(|i| (i % 256) as u8).collect();
        let password = "test-password-123";

        let encrypted = encrypt_model(&original, password).unwrap();
        let decrypted = decrypt_model(&encrypted, password).unwrap();

        assert_eq!(original, decrypted);
    }

    #[test]
    fn test_encrypt_decrypt_1mb_data() {
        let original: Vec<u8> = (0..1024 * 1024).map(|i| (i % 256) as u8).collect();
        let password = "strong-password";

        let encrypted = encrypt_model(&original, password).unwrap();
        let decrypted = decrypt_model(&encrypted, password).unwrap();

        assert_eq!(original.len(), decrypted.len());
        assert_eq!(original, decrypted);
    }

    #[test]
    fn test_empty_data_encrypt() {
        let original: &[u8] = &[];
        let password = "password";

        let encrypted = encrypt_model(original, password).unwrap();
        let decrypted = decrypt_model(&encrypted, password).unwrap();

        assert!(decrypted.is_empty());
    }

    /// Decrypting with a different password must fail rather than return garbage.
    #[test]
    fn test_wrong_password_fails() {
        let original = b"Secret model data";
        let password = "correct-password";
        let wrong_password = "wrong-password";

        let encrypted = encrypt_model(original, password).unwrap();
        let result = decrypt_model(&encrypted, wrong_password);

        assert!(result.is_err());
    }

    #[test]
    fn test_empty_password_rejected() {
        let data = b"test data";
        let result = encrypt_model(data, "");

        assert!(result.is_err());
    }

    #[test]
    fn test_corrupted_ciphertext_fails() {
        let original = b"Test data for corruption test";
        let password = "password";

        let mut encrypted = encrypt_model(original, password).unwrap();

        // Assert instead of silently skipping the corruption when the payload is short.
        assert!(encrypted.len() > HEADER_SIZE + 5);
        encrypted[HEADER_SIZE + 5] ^= 0xFF;

        let result = decrypt_model(&encrypted, password);
        assert!(result.is_err(), "Should detect ciphertext corruption");
    }

    #[test]
    fn test_corrupted_tag_fails() {
        let original = b"Test data";
        let password = "password";

        let mut encrypted = encrypt_model(original, password).unwrap();

        // The tag occupies the final bytes of the payload.
        let len = encrypted.len();
        encrypted[len - 1] ^= 0xFF;

        let result = decrypt_model(&encrypted, password);
        assert!(result.is_err(), "Should detect tag corruption");
    }

    #[test]
    fn test_truncated_data_fails() {
        let original = b"Test data";
        let password = "password";

        let encrypted = encrypt_model(original, password).unwrap();
        let truncated = &encrypted[..encrypted.len() - 10];

        let result = decrypt_model(truncated, password);
        assert!(result.is_err());
    }

    /// Input shorter than header plus tag is rejected before any key derivation.
    #[test]
    fn test_decrypt_rejects_input_shorter_than_header_and_tag() {
        let encrypted = encrypt_model(b"data", "password").unwrap();

        assert!(decrypt_model(&encrypted[..HEADER_SIZE], "password").is_err());
        assert!(decrypt_model(&encrypted[..HEADER_SIZE + TAG_LEN - 1], "password").is_err());
        assert!(decrypt_model(&[], "password").is_err());
    }

    #[test]
    fn test_is_encrypted() {
        let original = b"Plain data";
        let password = "password";

        assert!(!is_encrypted(original));

        let encrypted = encrypt_model(original, password).unwrap();
        assert!(is_encrypted(&encrypted));
    }

    /// `is_encrypted` must not panic on inputs shorter than the magic prefix.
    #[test]
    fn test_is_encrypted_short_inputs() {
        assert!(!is_encrypted(&[]));
        assert!(!is_encrypted(&MAGIC[..MAGIC.len() - 1]));
        assert!(is_encrypted(MAGIC));
    }

    #[test]
    fn test_get_version() {
        let original = b"Test";
        let password = "pwd";

        let encrypted = encrypt_model(original, password).unwrap();
        let version = get_version(&encrypted).unwrap();

        assert_eq!(version, VERSION);
    }

    #[test]
    fn test_get_version_error_paths() {
        // Magic alone has no version byte.
        assert!(get_version(MAGIC).is_err());
        assert!(get_version(&[]).is_err());
        assert!(get_version(b"NOTMAGIC\x01").is_err());
    }

    #[test]
    fn test_header_serialization() {
        let header = EncryptedHeader::new();
        let bytes = header.to_bytes();
        let parsed = EncryptedHeader::from_bytes(&bytes).unwrap();

        assert_eq!(header.version, parsed.version);
        assert_eq!(header.salt, parsed.salt);
        assert_eq!(header.nonce, parsed.nonce);
    }

    /// The serialized header is exactly `MAGIC | VERSION | salt | nonce`.
    #[test]
    fn test_header_bytes_layout() {
        let header = EncryptedHeader::new();
        let bytes = header.to_bytes();

        assert_eq!(bytes.len(), HEADER_SIZE);
        assert_eq!(&bytes[..MAGIC.len()], MAGIC);
        assert_eq!(bytes[MAGIC.len()], VERSION);
        assert_eq!(&bytes[MAGIC.len() + 1..MAGIC.len() + 1 + SALT_LEN], header.salt);
        assert_eq!(&bytes[HEADER_SIZE - NONCE_LEN..], header.nonce);
    }

    #[test]
    fn test_header_too_short_rejected() {
        let bytes = EncryptedHeader::new().to_bytes();

        assert!(EncryptedHeader::from_bytes(&bytes[..HEADER_SIZE - 1]).is_err());
        assert!(EncryptedHeader::from_bytes(&[]).is_err());
    }

    #[test]
    fn test_invalid_magic() {
        let mut data = vec![0u8; 100];
        data[0..8].copy_from_slice(b"NOTMAGIC");

        let result = EncryptedHeader::from_bytes(&data);
        assert!(result.is_err());
    }

    #[test]
    fn test_unsupported_version() {
        let mut data = vec![0u8; 100];
        data[0..8].copy_from_slice(MAGIC);
        data[8] = 99;

        let result = EncryptedHeader::from_bytes(&data);
        assert!(result.is_err());
    }

    #[test]
    fn test_encryption_config_default() {
        let config = EncryptionConfig::default();

        assert_eq!(config.memory_cost_kib, 65536);
        assert_eq!(config.time_cost, 3);
        assert_eq!(config.parallelism, 4);
    }

    #[test]
    fn test_encrypt_with_custom_config() {
        let original = b"Custom config test";
        let password = "password";

        let config = EncryptionConfig { memory_cost_kib: 32768, time_cost: 2, parallelism: 2 };

        let encrypted = encrypt_model_with_config(original, password, &config).unwrap();
        let decrypted = decrypt_model_with_config(&encrypted, password, &config).unwrap();

        assert_eq!(original.as_slice(), decrypted.as_slice());
    }

    #[test]
    fn test_special_characters_in_password() {
        let original = b"Test data";
        let password = "p@$$w0rd!#$%^&*()_+-=[]{}|;':\",./<>?";

        let encrypted = encrypt_model(original, password).unwrap();
        let decrypted = decrypt_model(&encrypted, password).unwrap();

        assert_eq!(original.as_slice(), decrypted.as_slice());
    }

    #[test]
    fn test_unicode_password() {
        let original = b"Test data";
        let password = "密码🔐пароль";

        let encrypted = encrypt_model(original, password).unwrap();
        let decrypted = decrypt_model(&encrypted, password).unwrap();

        assert_eq!(original.as_slice(), decrypted.as_slice());
    }

    #[test]
    fn test_very_long_password() {
        let original = b"Test data";
        let password = "a".repeat(10000);

        let encrypted = encrypt_model(original, &password).unwrap();
        let decrypted = decrypt_model(&encrypted, &password).unwrap();

        assert_eq!(original.as_slice(), decrypted.as_slice());
    }

    /// Fresh salt and nonce per call mean identical inputs never give identical output.
    #[test]
    fn test_different_encryptions_produce_different_ciphertext() {
        let original = b"Same data";
        let password = "same-password";

        let encrypted1 = encrypt_model(original, password).unwrap();
        let encrypted2 = encrypt_model(original, password).unwrap();

        assert_ne!(encrypted1, encrypted2);

        let decrypted1 = decrypt_model(&encrypted1, password).unwrap();
        let decrypted2 = decrypt_model(&encrypted2, password).unwrap();
        assert_eq!(decrypted1, decrypted2);
    }

    #[test]
    fn test_different_passwords_produce_different_ciphertext() {
        let original = b"Same data";

        let encrypted1 = encrypt_model(original, "password1").unwrap();
        let encrypted2 = encrypt_model(original, "password2").unwrap();

        assert_ne!(encrypted1, encrypted2);
    }

    #[test]
    fn test_encryption_overhead() {
        let original = b"Test data for size check";
        let password = "password";

        let encrypted = encrypt_model(original, password).unwrap();

        let min_overhead = HEADER_SIZE + TAG_LEN;
        assert!(encrypted.len() >= original.len() + min_overhead);
    }

    /// Output is exactly header + plaintext length + tag (a stream cipher with no padding).
    #[test]
    fn test_exact_ciphertext_length() {
        let original = b"Exact length check";

        let encrypted = encrypt_model(original, "password").unwrap();

        assert_eq!(encrypted.len(), HEADER_SIZE + original.len() + TAG_LEN);
    }

    #[test]
    fn test_single_byte_data() {
        let original = &[0x42u8];
        let password = "password";

        let encrypted = encrypt_model(original, password).unwrap();
        let decrypted = decrypt_model(&encrypted, password).unwrap();

        assert_eq!(original.as_slice(), decrypted.as_slice());
    }

    #[test]
    fn test_binary_data_with_nulls() {
        let original: Vec<u8> = vec![0, 0, 0, 1, 2, 3, 0, 0, 0];
        let password = "password";

        let encrypted = encrypt_model(&original, password).unwrap();
        let decrypted = decrypt_model(&encrypted, password).unwrap();

        assert_eq!(original, decrypted);
    }

    #[test]
    fn test_all_zeros_data() {
        let original = vec![0u8; 1000];
        let password = "password";

        let encrypted = encrypt_model(&original, password).unwrap();
        let decrypted = decrypt_model(&encrypted, password).unwrap();

        assert_eq!(original, decrypted);
    }

    #[test]
    fn test_all_ones_data() {
        let original = vec![0xFFu8; 1000];
        let password = "password";

        let encrypted = encrypt_model(&original, password).unwrap();
        let decrypted = decrypt_model(&encrypted, password).unwrap();

        assert_eq!(original, decrypted);
    }
}