1use std::{
22 fs,
23 io::{Cursor, Read, Seek, SeekFrom, Write},
24 path::{Path, PathBuf},
25};
26
27use anyhow::{Context, Result, anyhow, ensure};
28use argon2::Argon2;
29use byteorder::{ReadBytesExt, WriteBytesExt};
30use chacha20poly1305::{
31 XChaCha20Poly1305, XNonce,
32 aead::{Aead, KeyInit, Payload},
33};
34use dashmap::DashMap;
35use log::{debug, warn};
36use rand::prelude::*;
37use rayon::prelude::*;
38use tempfile::NamedTempFile;
39use zeroize::Zeroizing;
40
41use crate::{repo::Repo, utils::list_files};
42
/// Magic bytes at the start of every file produced by this module.
const MAGIC: &[u8; 5] = b"GITSE";
/// On-disk format version written into, and required from, every header.
const VERSION: u8 = 2;
/// Header flag bit: the plaintext was zstd-compressed before encryption.
const FLAG_COMPRESSED: u8 = 1 << 0;
/// Identifier of the only AEAD this module implements (XChaCha20-Poly1305).
const ENC_ALGO: u8 = 1;
/// Length of the key-derivation salt stored in the header.
const SALT_LEN: usize = 16;
/// Length of an XChaCha20-Poly1305 nonce.
const NONCE_LEN: usize = 24;
/// Fixed on-disk size of the header, padding included.
const HEADER_LEN: usize = 64;
/// Zero padding that follows the header fields: the magic bytes, three
/// single-byte fields (version, flags, algorithm id), the salt and the nonce.
const RESERVED_LEN: usize = HEADER_LEN - MAGIC.len() - 3 - SALT_LEN - NONCE_LEN;
/// Maximum number of plaintext bytes sealed into one AEAD chunk.
const CHUNK_SIZE: usize = 64 * 1024;

/// In-memory form of the fixed-size header at the start of an encrypted file.
#[derive(Debug)]
pub struct FileHeader {
    /// Format version (`VERSION` for every header this module creates or accepts).
    version: u8,
    /// Bit set of `FLAG_*` values.
    flags: u8,
    /// Encryption algorithm id (`ENC_ALGO`).
    enc_algo: u8,
    /// Salt that was fed to the key-derivation function.
    salt: [u8; SALT_LEN],
    /// Random per-file base nonce from which every chunk nonce is derived.
    nonce: [u8; NONCE_LEN],
}
69
70impl FileHeader {
71 #[must_use]
72 pub fn new(compressed: bool, salt: [u8; SALT_LEN]) -> Self {
73 let mut rng = rand::rng();
74 let mut nonce = [0u8; NONCE_LEN];
75 rng.fill_bytes(&mut nonce);
76
77 let mut flags = 0u8;
78 if compressed {
79 flags |= FLAG_COMPRESSED;
80 }
81
82 Self {
83 version: VERSION,
84 flags,
85 enc_algo: ENC_ALGO,
86 salt,
87 nonce,
88 }
89 }
90
91 pub fn write<W: Write>(&self, writer: &mut W) -> Result<()> {
93 writer.write_all(MAGIC)?;
94 writer.write_u8(self.version)?;
95 writer.write_u8(self.flags)?;
96 writer.write_u8(self.enc_algo)?;
97 writer.write_all(&self.salt)?;
98 writer.write_all(&self.nonce)?;
99 let reserved = [0u8; RESERVED_LEN];
100 writer.write_all(&reserved)?;
101 Ok(())
102 }
103
104 pub fn read<R: Read>(reader: &mut R) -> Result<Self> {
106 let mut magic_buf = [0u8; 5];
107 reader
108 .read_exact(&mut magic_buf)
109 .context("Failed to read magic")?;
110 if &magic_buf != MAGIC {
111 return Err(anyhow!("Invalid magic bytes"));
112 }
113
114 let version = reader.read_u8()?;
115 if version != VERSION {
116 return Err(anyhow!("Unsupported version: {version}"));
117 }
118
119 let flags = reader.read_u8()?;
120 let enc_algo = reader.read_u8()?;
121 if enc_algo != ENC_ALGO {
122 return Err(anyhow!("Unsupported encryption algorithm: {enc_algo}"));
123 }
124
125 let mut salt = [0u8; SALT_LEN];
126 reader.read_exact(&mut salt)?;
127 let mut nonce = [0u8; NONCE_LEN];
128 reader.read_exact(&mut nonce)?;
129 let mut reserved = [0u8; RESERVED_LEN];
130 reader.read_exact(&mut reserved)?;
131
132 Ok(Self {
133 version,
134 flags,
135 enc_algo,
136 salt,
137 nonce,
138 })
139 }
140
141 #[must_use]
142 pub const fn is_compressed(&self) -> bool {
143 (self.flags & FLAG_COMPRESSED) != 0
144 }
145}
146
147fn derive_key(password: &[u8], salt: &[u8]) -> Result<Zeroizing<[u8; 32]>> {
153 let mut key = Zeroizing::new([0u8; 32]);
154 Argon2::default()
155 .hash_password_into(password, salt, &mut *key)
156 .map_err(|e| anyhow!("Argon2 key derivation failed: {e}"))?;
157 Ok(key)
158}
159
160fn derive_nonce(base_nonce: &[u8; NONCE_LEN], chunk_idx: u64) -> XNonce {
163 let mut nonce_bytes = *base_nonce;
164 nonce_bytes[16..24].copy_from_slice(&chunk_idx.to_le_bytes());
165 XNonce::from(nonce_bytes)
166}
167
168fn atomic_write_with_metadata(original_path: &Path, temp_file: NamedTempFile) -> Result<()> {
171 if let Err(e) = copy_metadata::copy_metadata(original_path, temp_file.path()) {
174 warn!(
175 "Could not copy metadata for {}: {}",
176 original_path.display(),
177 e
178 );
179 }
180 temp_file.persist(original_path).with_context(|| {
181 format!(
182 "Failed to persist atomic write to {}",
183 original_path.display()
184 )
185 })?;
186 Ok(())
187}
188
189pub fn encrypt_file(
193 path: &Path,
194 derived_key: &[u8; 32],
195 salt: &[u8; SALT_LEN],
196 zstd: Option<u8>,
197) -> Result<()> {
198 let mut file = fs::File::open(path)?;
199
200 let mut header_bytes = [0u8; HEADER_LEN];
202 if file.read_exact(&mut header_bytes).is_ok()
203 && &header_bytes[0..5] == MAGIC
204 && header_bytes[5] == VERSION
205 {
206 warn!("File already encrypted, skipping: {}", path.display());
207 return Ok(());
208 }
209 file.seek(SeekFrom::Start(0))?; debug!("Encrypting: {}", path.display());
212
213 let header = FileHeader::new(zstd.is_some(), *salt);
215 let parent_dir = path.parent().unwrap_or_else(|| Path::new("."));
216 let mut temp_file = NamedTempFile::new_in(parent_dir)
217 .with_context(|| "Failed to create temp file".to_string())?;
218
219 header.write(&mut temp_file)?;
220
221 let cipher = XChaCha20Poly1305::new(derived_key.into());
223
224 let mut reader: Box<dyn Read> = if let Some(zstd_level) = zstd {
226 Box::new(zstd::stream::read::Encoder::new(
227 file,
228 i32::from(zstd_level),
229 )?)
230 } else {
231 Box::new(file)
232 };
233
234 let mut buffer = Zeroizing::new(vec![0u8; CHUNK_SIZE]);
236 let mut chunk_idx = 0u64;
237
238 loop {
239 let mut bytes_read = 0;
240 while bytes_read < CHUNK_SIZE {
241 let n = reader.read(&mut buffer[bytes_read..])?;
242 if n == 0 {
243 break;
244 }
245 bytes_read += n;
246 }
247
248 let is_last_chunk = bytes_read < CHUNK_SIZE;
249 let aad = if is_last_chunk { b"LAST" } else { b"MORE" };
250 let nonce = derive_nonce(&header.nonce, chunk_idx);
251
252 let payload = Payload {
253 msg: &buffer[..bytes_read],
254 aad,
255 };
256
257 let ciphertext = cipher
258 .encrypt(&nonce, payload)
259 .map_err(|e| anyhow!("Encryption failed: {e}"))?;
260
261 temp_file.write_all(&ciphertext)?;
262 chunk_idx += 1;
263
264 if is_last_chunk {
265 break;
266 }
267 }
268
269 drop(reader);
271 atomic_write_with_metadata(path, temp_file)?;
272
273 Ok(())
274}
275
276pub fn decrypt_file(path: &Path, master_key: &[u8]) -> Result<()> {
278 let key_cache = DashMap::new();
279 decrypt_file_with_cache(path, &key_cache, master_key)
280}
281
282#[allow(clippy::type_complexity)]
289pub fn decrypt_file_with_cache<S: ::std::hash::BuildHasher + Clone>(
290 path: &Path,
291 key_cache: &DashMap<[u8; SALT_LEN], Zeroizing<[u8; 32]>, S>,
292 master_key: &[u8],
293) -> Result<()> {
294 let mut file = fs::File::open(path)?;
295
296 let mut header_bytes = [0u8; HEADER_LEN];
298 if file.read_exact(&mut header_bytes).is_err() {
299 debug!(
300 "File too small to be encrypted, skipping: {}",
301 path.display()
302 );
303 return Ok(());
304 }
305 if &header_bytes[0..5] != MAGIC || header_bytes[5] != VERSION {
306 debug!(
307 "File not encrypted (no magic), skipping: {}",
308 path.display()
309 );
310 return Ok(());
311 }
312
313 debug!("Decrypting: {}", path.display());
314 let header = FileHeader::read(&mut Cursor::new(&header_bytes))
315 .with_context(|| format!("Corrupt header in {}", path.display()))?;
316
317 let derived_key = {
319 if let Some(k) = key_cache.get(&header.salt) {
320 k.clone()
321 } else {
322 let k = derive_key(master_key, &header.salt)?;
323 key_cache.insert(header.salt, k.clone());
324 k
325 }
326 };
327
328 let cipher = XChaCha20Poly1305::new(derived_key.as_ref().into());
330 let parent_dir = path.parent().unwrap_or_else(|| Path::new("."));
331 let mut temp_file = NamedTempFile::new_in(parent_dir)
332 .with_context(|| "Failed to create temp file".to_string())?;
333
334 if header.is_compressed() {
336 let mut decoder = zstd::stream::write::Decoder::new(&mut temp_file)?.auto_flush();
337 decrypt_chunks(&mut file, &mut decoder, &cipher, &header.nonce)?;
338 decoder.flush()?;
339 } else {
340 decrypt_chunks(&mut file, &mut temp_file, &cipher, &header.nonce)?;
341 }
342 drop(file);
343
344 atomic_write_with_metadata(path, temp_file)?;
346
347 Ok(())
348}
349
350fn decrypt_chunks(
353 file: &mut fs::File,
354 writer: &mut dyn Write,
355 cipher: &XChaCha20Poly1305,
356 base_nonce: &[u8; NONCE_LEN],
357) -> Result<()> {
358 let mut buffer = vec![0u8; CHUNK_SIZE + 16];
360 let mut chunk_idx = 0u64;
361 let mut last_chunk_was_final = false;
362
363 loop {
364 let mut bytes_read = 0;
365 while bytes_read < buffer.len() {
366 let n = file.read(&mut buffer[bytes_read..])?;
367 if n == 0 {
368 break;
369 }
370 bytes_read += n;
371 }
372
373 if bytes_read == 0 {
374 break; }
376
377 let is_last_chunk = bytes_read < buffer.len();
378 let aad = if is_last_chunk { b"LAST" } else { b"MORE" };
379 let nonce = derive_nonce(base_nonce, chunk_idx);
380
381 let payload = chacha20poly1305::aead::Payload {
382 msg: &buffer[..bytes_read],
383 aad,
384 };
385
386 let plaintext = Zeroizing::new(cipher.decrypt(&nonce, payload).map_err(|e| {
387 anyhow!("Decryption failed (wrong password, corrupt, or tampered data): {e}")
388 })?);
389
390 writer.write_all(&plaintext)?;
391 chunk_idx += 1;
392
393 if is_last_chunk {
394 last_chunk_was_final = true;
395 break;
396 }
397 }
398
399 if !last_chunk_was_final {
400 return Err(anyhow!(
401 "File truncation detected! The ciphertext is incomplete."
402 ));
403 }
404
405 Ok(())
406}
407
408pub fn encrypt_repo(repo: &'static Repo, paths: Vec<PathBuf>) -> Result<()> {
416 let key = repo.get_key();
417 assert!(!key.is_empty(), "Key must not be empty");
418
419 let target_files = if paths.is_empty() {
420 list_files(repo.conf.crypt_list.iter(), repo.path())
421 } else {
422 list_files(paths, repo.path())
423 };
424 ensure!(!target_files.is_empty(), "No file to encrypt");
425
426 let mut salt = [0u8; SALT_LEN];
428 rand::rng().fill_bytes(&mut salt);
429
430 let derived_key = derive_key(key.as_bytes(), &salt)?;
432
433 target_files.par_iter().try_for_each(|f| -> Result<()> {
435 encrypt_file(
436 f,
437 &derived_key,
438 &salt,
439 repo.conf.use_zstd.then_some(repo.conf.zstd_level),
440 )
441 .with_context(|| format!("Failed to encrypt {}", f.display()))
442 })?;
443
444 Ok(())
445}
446
447pub fn decrypt_repo(repo: &'static Repo, paths: Vec<PathBuf>) -> Result<()> {
453 let key = repo.get_key();
454 assert!(!key.is_empty(), "Master key must not be empty");
455
456 let target_files = if paths.is_empty() {
457 list_files(repo.conf.crypt_list.iter(), repo.path())
458 } else {
459 list_files(paths, repo.path())
460 };
461 ensure!(!target_files.is_empty(), "No file to decrypt");
462
463 target_files
465 .par_iter()
466 .filter(|p| p.is_file())
467 .try_for_each(|f| -> Result<()> {
468 decrypt_file(f, key.as_bytes())
469 .with_context(|| format!("Failed to decrypt {}", f.display()))
470 })?;
471
472 Ok(())
473}
474
#[cfg(test)]
mod tests {
    use std::io::{Cursor, Write};

    use tempfile::{NamedTempFile, TempPath};

    use super::*;

    /// Password used both to derive test keys and as the decryption master key.
    const PASSWORD: &[u8] = b"super_secret_password";

    /// Size of the Poly1305 tag appended to every encrypted chunk.
    const TAG_LEN: usize = 16;

    /// Derives a key from `PASSWORD` with a fresh random salt, as
    /// `encrypt_repo` does once per run.
    fn get_test_key_and_salt() -> ([u8; 32], [u8; SALT_LEN]) {
        let mut salt = [0u8; SALT_LEN];
        rand::rng().fill_bytes(&mut salt);
        let derived = derive_key(PASSWORD, &salt).unwrap();
        (*derived, salt)
    }

    /// Writes `content` to a new temp file that is deleted when the returned
    /// path is dropped.
    fn create_temp_file(content: &[u8]) -> TempPath {
        let mut file = NamedTempFile::new().unwrap();
        file.write_all(content).unwrap();
        file.flush().unwrap();
        file.into_temp_path()
    }

    #[test]
    fn test_header_serialization() {
        let salt = [0xAB; SALT_LEN];
        let header = FileHeader::new(true, salt);

        let mut buf = Vec::new();
        header.write(&mut buf).unwrap();
        assert_eq!(buf.len(), HEADER_LEN);

        let decoded = FileHeader::read(&mut Cursor::new(buf)).unwrap();
        assert_eq!(decoded.version, VERSION);
        assert_eq!(decoded.flags, FLAG_COMPRESSED);
        assert_eq!(decoded.enc_algo, ENC_ALGO);
        assert_eq!(decoded.salt, salt);
        assert_eq!(decoded.nonce, header.nonce);
        assert!(decoded.is_compressed());
    }

    #[test]
    fn test_nonce_derivation() {
        let base_nonce = [0u8; NONCE_LEN];

        // Index 0 leaves an all-zero base nonce unchanged.
        let nonce0 = derive_nonce(&base_nonce, 0);
        assert_eq!(nonce0.as_slice(), &[0u8; NONCE_LEN]);

        // The index is written little-endian into bytes 16..24.
        let nonce1 = derive_nonce(&base_nonce, 1);
        let mut expected1 = [0u8; NONCE_LEN];
        expected1[16] = 1;
        assert_eq!(nonce1.as_slice(), &expected1);

        let nonce256 = derive_nonce(&base_nonce, 256);
        let mut expected256 = [0u8; NONCE_LEN];
        expected256[17] = 1;
        assert_eq!(nonce256.as_slice(), &expected256);
    }

    #[test]
    fn test_encrypt_decrypt_basic_no_compression() {
        let plaintext = b"Hello, World! This is a test without compression.";
        let path = create_temp_file(plaintext);
        let (key, salt) = get_test_key_and_salt();

        encrypt_file(&path, &key, &salt, None).unwrap();
        let encrypted = fs::read(&path).unwrap();
        assert_ne!(encrypted, plaintext);
        assert_eq!(&encrypted[0..5], MAGIC);

        decrypt_file(&path, PASSWORD).unwrap();
        assert_eq!(fs::read(&path).unwrap(), plaintext);
    }

    #[test]
    fn test_encrypt_decrypt_with_compression() {
        // Highly compressible input, so the encrypted file must be much smaller.
        let plaintext = b"A".repeat(10000);
        let path = create_temp_file(&plaintext);
        let (key, salt) = get_test_key_and_salt();

        encrypt_file(&path, &key, &salt, Some(3)).unwrap();
        assert!(fs::metadata(&path).unwrap().len() < 5000);

        decrypt_file(&path, PASSWORD).unwrap();
        assert_eq!(fs::read(&path).unwrap(), plaintext);
    }

    #[test]
    fn test_chunked_encryption_large_file() {
        // Spans two chunks, the second one partial.
        let plaintext: Vec<u8> = (0..=255u8).cycle().take(100_000).collect();
        let path = create_temp_file(&plaintext);
        let (key, salt) = get_test_key_and_salt();

        encrypt_file(&path, &key, &salt, None).unwrap();
        decrypt_file(&path, PASSWORD).unwrap();
        assert_eq!(fs::read(&path).unwrap(), plaintext);
    }

    #[test]
    fn test_roundtrip_empty_file() {
        let path = create_temp_file(b"");
        let (key, salt) = get_test_key_and_salt();

        encrypt_file(&path, &key, &salt, None).unwrap();
        // Header plus one empty "LAST" chunk, which is just its tag.
        assert_eq!(fs::read(&path).unwrap().len(), HEADER_LEN + TAG_LEN);

        decrypt_file(&path, PASSWORD).unwrap();
        assert!(fs::read(&path).unwrap().is_empty());
    }

    #[test]
    fn test_roundtrip_exact_chunk_multiple() {
        // A plaintext that exactly fills its chunks must still be followed by
        // an empty final chunk, so the decryptor can tell it is complete.
        let plaintext = vec![0x5Au8; 2 * CHUNK_SIZE];
        let path = create_temp_file(&plaintext);
        let (key, salt) = get_test_key_and_salt();

        encrypt_file(&path, &key, &salt, None).unwrap();
        let expected_len = HEADER_LEN + 2 * (CHUNK_SIZE + TAG_LEN) + TAG_LEN;
        assert_eq!(fs::read(&path).unwrap().len(), expected_len);

        decrypt_file(&path, PASSWORD).unwrap();
        assert_eq!(fs::read(&path).unwrap(), plaintext);
    }

    #[test]
    fn test_truncation_detected() {
        let plaintext = vec![0x5Au8; 2 * CHUNK_SIZE];
        let path = create_temp_file(&plaintext);
        let (key, salt) = get_test_key_and_salt();
        encrypt_file(&path, &key, &salt, None).unwrap();

        // Drop the empty final chunk so the stream ends on a full-chunk boundary.
        let encrypted = fs::read(&path).unwrap();
        fs::write(&path, &encrypted[..encrypted.len() - TAG_LEN]).unwrap();

        let err = decrypt_file(&path, PASSWORD).unwrap_err();
        assert!(
            err.to_string().contains("truncation"),
            "unexpected error: {err}"
        );
        // The failed decryption must not have replaced the file.
        assert_eq!(&fs::read(&path).unwrap()[..MAGIC.len()], MAGIC);
    }

    #[test]
    fn test_already_encrypted_file_is_skipped() {
        let plaintext = b"encrypt me only once";
        let path = create_temp_file(plaintext);
        let (key, salt) = get_test_key_and_salt();

        encrypt_file(&path, &key, &salt, None).unwrap();
        let once = fs::read(&path).unwrap();

        encrypt_file(&path, &key, &salt, None).unwrap();
        assert_eq!(fs::read(&path).unwrap(), once);

        decrypt_file(&path, PASSWORD).unwrap();
        assert_eq!(fs::read(&path).unwrap(), plaintext);
    }

    #[test]
    fn test_tamper_resistance() {
        let plaintext = b"Sensitive data that should not be tampered with.";
        let path = create_temp_file(plaintext);
        let (key, salt) = get_test_key_and_salt();

        encrypt_file(&path, &key, &salt, None).unwrap();

        // Flip one byte inside the first chunk's ciphertext.
        let mut encrypted = fs::read(&path).unwrap();
        encrypted[HEADER_LEN + 5] ^= 0xFF;
        fs::write(&path, &encrypted).unwrap();

        let result = decrypt_file(&path, PASSWORD);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Decryption failed")
        );
    }

    #[cfg(unix)]
    #[test]
    fn test_metadata_preservation() {
        use std::os::unix::fs::PermissionsExt;

        let file = create_temp_file(b"Executable script content");
        let path: &Path = &file;

        let mut perms = fs::metadata(path).unwrap().permissions();
        perms.set_mode(0o755);
        fs::set_permissions(path, perms).unwrap();

        let (key, salt) = get_test_key_and_salt();

        encrypt_file(path, &key, &salt, None).unwrap();
        let encrypted_mode = fs::metadata(path).unwrap().permissions().mode();
        assert_eq!(encrypted_mode & 0o777, 0o755);

        let key_cache = DashMap::new();
        decrypt_file_with_cache(path, &key_cache, PASSWORD).unwrap();
        assert_eq!(key_cache.len(), 1, "derived key should be cached by salt");

        let decrypted_mode = fs::metadata(path).unwrap().permissions().mode();
        assert_eq!(decrypted_mode & 0o777, 0o755);
    }
}