use std::{
fs,
io::{Cursor, Read, Seek, SeekFrom, Write},
path::{Path, PathBuf},
};
use anyhow::{Context, Result, anyhow, ensure};
use argon2::Argon2;
use byteorder::{ReadBytesExt, WriteBytesExt};
use chacha20poly1305::{
XChaCha20Poly1305, XNonce,
aead::{Aead, KeyInit, Payload},
};
use dashmap::DashMap;
use log::{debug, warn};
use rand::prelude::*;
use rayon::prelude::*;
use tempfile::NamedTempFile;
use zeroize::Zeroizing;
use crate::{repo::Repo, utils::list_files};
/// Magic bytes that open every encrypted file.
const MAGIC: &[u8; 5] = b"GITSE";
/// On-disk format version written by, and the only one accepted by, this module.
const VERSION: u8 = 2;
/// Header flag bit: the payload was zstd-compressed before encryption.
const FLAG_COMPRESSED: u8 = 0b0000_0001;
/// Header identifier for the AEAD in use (XChaCha20-Poly1305).
const ENC_ALGO: u8 = 1;
/// Length in bytes of the Argon2 salt stored in the header.
const SALT_LEN: usize = 16;
/// Length in bytes of the XChaCha20-Poly1305 base nonce stored in the header.
const NONCE_LEN: usize = 24;
/// Total serialized header size in bytes.
const HEADER_LEN: usize = 64;
/// Zero padding after the named fields; the `3` covers the version, flags and algorithm bytes.
const RESERVED_LEN: usize = HEADER_LEN - MAGIC.len() - 3 - SALT_LEN - NONCE_LEN;
/// Plaintext bytes per encrypted chunk (64 KiB).
const CHUNK_SIZE: usize = 64 * 1024;
/// Fixed-size header (`HEADER_LEN` bytes once serialized) at the start of every encrypted file.
#[derive(Debug)]
pub struct FileHeader {
    /// Format version; reading rejects anything other than `VERSION`.
    version: u8,
    /// Bit flags; `FLAG_COMPRESSED` marks a zstd-compressed payload.
    flags: u8,
    /// Encryption algorithm id; reading rejects anything other than `ENC_ALGO`.
    enc_algo: u8,
    /// Argon2 salt from which this file's key is derived.
    salt: [u8; SALT_LEN],
    /// Random base nonce; each chunk's nonce is derived from it and the chunk index.
    nonce: [u8; NONCE_LEN],
}
impl FileHeader {
#[must_use]
pub fn new(compressed: bool, salt: [u8; SALT_LEN]) -> Self {
let mut rng = rand::rng();
let mut nonce = [0u8; NONCE_LEN];
rng.fill_bytes(&mut nonce);
let mut flags = 0u8;
if compressed {
flags |= FLAG_COMPRESSED;
}
Self {
version: VERSION,
flags,
enc_algo: ENC_ALGO,
salt,
nonce,
}
}
pub fn write<W: Write>(&self, writer: &mut W) -> Result<()> {
writer.write_all(MAGIC)?;
writer.write_u8(self.version)?;
writer.write_u8(self.flags)?;
writer.write_u8(self.enc_algo)?;
writer.write_all(&self.salt)?;
writer.write_all(&self.nonce)?;
let reserved = [0u8; RESERVED_LEN];
writer.write_all(&reserved)?;
Ok(())
}
pub fn read<R: Read>(reader: &mut R) -> Result<Self> {
let mut magic_buf = [0u8; 5];
reader
.read_exact(&mut magic_buf)
.context("Failed to read magic")?;
if &magic_buf != MAGIC {
return Err(anyhow!("Invalid magic bytes"));
}
let version = reader.read_u8()?;
if version != VERSION {
return Err(anyhow!("Unsupported version: {version}"));
}
let flags = reader.read_u8()?;
let enc_algo = reader.read_u8()?;
if enc_algo != ENC_ALGO {
return Err(anyhow!("Unsupported encryption algorithm: {enc_algo}"));
}
let mut salt = [0u8; SALT_LEN];
reader.read_exact(&mut salt)?;
let mut nonce = [0u8; NONCE_LEN];
reader.read_exact(&mut nonce)?;
let mut reserved = [0u8; RESERVED_LEN];
reader.read_exact(&mut reserved)?;
Ok(Self {
version,
flags,
enc_algo,
salt,
nonce,
})
}
#[must_use]
pub const fn is_compressed(&self) -> bool {
(self.flags & FLAG_COMPRESSED) != 0
}
}
/// Derives a 32-byte symmetric key from `password` and `salt` using `Argon2::default()`.
///
/// The key is returned inside `Zeroizing`, so it is wiped from memory when dropped.
///
/// # Errors
/// Fails if Argon2 rejects the inputs (e.g. a salt length it does not accept).
fn derive_key(password: &[u8], salt: &[u8]) -> Result<Zeroizing<[u8; 32]>> {
    let mut output = Zeroizing::new([0u8; 32]);
    match Argon2::default().hash_password_into(password, salt, output.as_mut_slice()) {
        Ok(()) => Ok(output),
        Err(e) => Err(anyhow!("Argon2 key derivation failed: {e}")),
    }
}
/// Builds the nonce for chunk number `chunk_idx` of one file.
///
/// The leading bytes come from the file's random base nonce. The trailing 8 bytes are
/// replaced by the little-endian chunk index, so each chunk of a file gets a distinct nonce.
fn derive_nonce(base_nonce: &[u8; NONCE_LEN], chunk_idx: u64) -> XNonce {
    let counter = chunk_idx.to_le_bytes();
    let mut bytes = *base_nonce;
    let (_, tail) = bytes.split_at_mut(NONCE_LEN - counter.len());
    tail.copy_from_slice(&counter);
    XNonce::from(bytes)
}
/// Replaces `original_path` with the fully written `temp_file`.
///
/// Steps:
/// 1. Copy the original file's metadata (e.g. permissions) onto the temporary file. This is
///    best effort: a failure is logged as a warning and does not abort the write.
/// 2. Flush the temporary file's data and metadata to stable storage.
/// 3. Move it over the original with `NamedTempFile::persist` (a rename).
///
/// Callers in this module create `temp_file` in the target's parent directory, so the
/// rename stays on one filesystem.
///
/// # Errors
/// Returns an error if the temporary file cannot be synced, or if it cannot be moved over
/// `original_path`.
fn atomic_write_with_metadata(original_path: &Path, temp_file: NamedTempFile) -> Result<()> {
    if let Err(e) = copy_metadata::copy_metadata(original_path, temp_file.path()) {
        warn!(
            "Could not copy metadata for {}: {}",
            original_path.display(),
            e
        );
    }
    // Without this, the rename can reach disk before the new contents do. A crash at that
    // point would leave an empty or partial file where the user's data used to be.
    temp_file.as_file().sync_all().with_context(|| {
        format!(
            "Failed to sync temporary file for {}",
            original_path.display()
        )
    })?;
    temp_file.persist(original_path).with_context(|| {
        format!(
            "Failed to persist atomic write to {}",
            original_path.display()
        )
    })?;
    Ok(())
}
/// Encrypts the file at `path` in place, using a key that has already been derived.
///
/// A file that already starts with this module's magic bytes and current `VERSION` is
/// left untouched, and a warning is logged. Otherwise the content is processed as follows:
/// - It is optionally zstd-compressed at level `zstd`.
/// - It is split into `CHUNK_SIZE` pieces, each sealed with XChaCha20-Poly1305.
/// - Each piece's nonce comes from a fresh random base nonce plus the chunk index.
/// - Full chunks are authenticated with the `b"MORE"` associated data. The final chunk
///   (shorter, possibly empty) uses `b"LAST"`, so truncation is detected on decryption.
///
/// The result goes to a temporary file next to `path`, which then replaces the original.
///
/// `salt` is only recorded in the header. It must be the salt `derived_key` was derived
/// from, because decryption re-derives the key from the header's salt.
///
/// # Errors
/// Returns an error if any of these fail: reading the source, creating or writing the
/// temporary file, compression, encryption, or the final replacement.
pub fn encrypt_file(
    path: &Path,
    derived_key: &[u8; 32],
    salt: &[u8; SALT_LEN],
    zstd: Option<u8>,
) -> Result<()> {
    /// Reads from `src` until `buf` is full or EOF, retrying reads interrupted by a signal.
    /// Returns the number of bytes read; a value below `buf.len()` means EOF was reached.
    fn fill_from<R: Read + ?Sized>(src: &mut R, buf: &mut [u8]) -> std::io::Result<usize> {
        let mut filled = 0;
        while filled < buf.len() {
            match src.read(&mut buf[filled..]) {
                Ok(0) => break,
                Ok(n) => filled += n,
                Err(e) if e.kind() == std::io::ErrorKind::Interrupted => {}
                Err(e) => return Err(e),
            }
        }
        Ok(filled)
    }

    let mut file = fs::File::open(path)?;
    // Skip files that already carry our header, so a repeated run does not double-encrypt.
    let mut header_bytes = [0u8; HEADER_LEN];
    if file.read_exact(&mut header_bytes).is_ok()
        && &header_bytes[0..5] == MAGIC
        && header_bytes[5] == VERSION
    {
        warn!("File already encrypted, skipping: {}", path.display());
        return Ok(());
    }
    file.seek(SeekFrom::Start(0))?;
    debug!("Encrypting: {}", path.display());

    let header = FileHeader::new(zstd.is_some(), *salt);
    let parent_dir = path.parent().unwrap_or_else(|| Path::new("."));
    let mut temp_file = NamedTempFile::new_in(parent_dir)
        .with_context(|| "Failed to create temp file".to_string())?;
    header.write(&mut temp_file)?;

    let cipher = XChaCha20Poly1305::new(derived_key.into());
    let mut reader: Box<dyn Read> = if let Some(zstd_level) = zstd {
        Box::new(zstd::stream::read::Encoder::new(
            file,
            i32::from(zstd_level),
        )?)
    } else {
        Box::new(file)
    };

    // Plaintext staging buffer; wiped on drop.
    let mut buffer = Zeroizing::new(vec![0u8; CHUNK_SIZE]);
    let mut chunk_idx = 0u64;
    loop {
        let bytes_read = fill_from(&mut reader, &mut buffer[..])?;
        // A short read means EOF, so this (possibly empty) chunk is the final one.
        let is_last_chunk = bytes_read < CHUNK_SIZE;
        let aad: &[u8] = if is_last_chunk { b"LAST" } else { b"MORE" };
        let nonce = derive_nonce(&header.nonce, chunk_idx);
        let payload = Payload {
            msg: &buffer[..bytes_read],
            aad,
        };
        let ciphertext = cipher
            .encrypt(&nonce, payload)
            .map_err(|e| anyhow!("Encryption failed: {e}"))?;
        temp_file.write_all(&ciphertext)?;
        if is_last_chunk {
            break;
        }
        chunk_idx += 1;
    }

    // Close the source file so it is no longer open when it gets replaced.
    drop(reader);
    atomic_write_with_metadata(path, temp_file)
}
/// Decrypts the file at `path` in place, using a key derived from `master_key`.
///
/// This is a convenience wrapper around `decrypt_file_with_cache` with a throwaway cache,
/// so the Argon2 derivation runs on every call for an encrypted file. When decrypting many
/// files, share one cache through `decrypt_file_with_cache` instead.
///
/// # Errors
/// Propagates every error from `decrypt_file_with_cache`.
pub fn decrypt_file(path: &Path, master_key: &[u8]) -> Result<()> {
    decrypt_file_with_cache(path, &DashMap::new(), master_key)
}
/// Decrypts the file at `path` in place, reusing derived keys across calls via `key_cache`.
///
/// Files are left untouched, with a debug log, when either of these holds:
/// - the file is shorter than `HEADER_LEN` bytes;
/// - it does not start with this module's magic bytes followed by the current `VERSION`.
///
/// Otherwise the key for the header's salt is taken from `key_cache`. On a miss it is
/// derived from `master_key` and inserted. The chunk stream is then decrypted into a
/// temporary file next to `path`, which replaces the original. The stream is also
/// zstd-decompressed when the header's compression flag is set.
///
/// The cache is keyed by salt, so one cache can serve many files and threads. Concurrent
/// misses on the same salt may each derive the key before the cache is populated.
///
/// # Errors
/// Returns an error if any of these fail: opening the file, parsing the header, key
/// derivation, or authenticating any chunk. It also errors when the ciphertext is truncated,
/// or when the temporary file cannot be created, written, or moved into place.
#[allow(clippy::type_complexity)]
pub fn decrypt_file_with_cache<S: ::std::hash::BuildHasher + Clone>(
    path: &Path,
    key_cache: &DashMap<[u8; SALT_LEN], Zeroizing<[u8; 32]>, S>,
    master_key: &[u8],
) -> Result<()> {
    let mut file = fs::File::open(path)?;

    let mut header_bytes = [0u8; HEADER_LEN];
    let Ok(()) = file.read_exact(&mut header_bytes) else {
        debug!(
            "File too small to be encrypted, skipping: {}",
            path.display()
        );
        return Ok(());
    };

    let (magic, rest) = header_bytes.split_at(MAGIC.len());
    if magic != MAGIC || rest[0] != VERSION {
        debug!(
            "File not encrypted (no magic), skipping: {}",
            path.display()
        );
        return Ok(());
    }

    debug!("Decrypting: {}", path.display());
    let header = FileHeader::read(&mut Cursor::new(&header_bytes))
        .with_context(|| format!("Corrupt header in {}", path.display()))?;

    // Copy the cached key out first, so no map guard is held while deriving on a miss.
    let cached = key_cache
        .get(&header.salt)
        .map(|entry| entry.value().clone());
    let derived_key = match cached {
        Some(key) => key,
        None => {
            let fresh = derive_key(master_key, &header.salt)?;
            key_cache.insert(header.salt, fresh.clone());
            fresh
        }
    };
    let cipher = XChaCha20Poly1305::new(derived_key.as_ref().into());

    let parent_dir = path.parent().unwrap_or_else(|| Path::new("."));
    let mut temp_file = NamedTempFile::new_in(parent_dir)
        .with_context(|| "Failed to create temp file".to_string())?;

    if header.is_compressed() {
        let mut decoder = zstd::stream::write::Decoder::new(&mut temp_file)?.auto_flush();
        decrypt_chunks(&mut file, &mut decoder, &cipher, &header.nonce)?;
        decoder.flush()?;
    } else {
        decrypt_chunks(&mut file, &mut temp_file, &cipher, &header.nonce)?;
    }

    // Release the source handle before it is replaced.
    drop(file);
    atomic_write_with_metadata(path, temp_file)
}
fn decrypt_chunks(
file: &mut fs::File,
writer: &mut dyn Write,
cipher: &XChaCha20Poly1305,
base_nonce: &[u8; NONCE_LEN],
) -> Result<()> {
let mut buffer = vec![0u8; CHUNK_SIZE + 16];
let mut chunk_idx = 0u64;
let mut last_chunk_was_final = false;
loop {
let mut bytes_read = 0;
while bytes_read < buffer.len() {
let n = file.read(&mut buffer[bytes_read..])?;
if n == 0 {
break;
}
bytes_read += n;
}
if bytes_read == 0 {
break; }
let is_last_chunk = bytes_read < buffer.len();
let aad = if is_last_chunk { b"LAST" } else { b"MORE" };
let nonce = derive_nonce(base_nonce, chunk_idx);
let payload = chacha20poly1305::aead::Payload {
msg: &buffer[..bytes_read],
aad,
};
let plaintext = Zeroizing::new(cipher.decrypt(&nonce, payload).map_err(|e| {
anyhow!("Decryption failed (wrong password, corrupt, or tampered data): {e}")
})?);
writer.write_all(&plaintext)?;
chunk_idx += 1;
if is_last_chunk {
last_chunk_was_final = true;
break;
}
}
if !last_chunk_was_final {
return Err(anyhow!(
"File truncation detected! The ciphertext is incomplete."
));
}
Ok(())
}
/// Encrypts the repository's configured files, or the given `paths`, in place.
///
/// When `paths` is empty, the entries of `repo.conf.crypt_list` are resolved via
/// `list_files`. A single random salt is generated per invocation, so Argon2 runs once.
/// Every file still gets its own random base nonce from `FileHeader::new`, so per-file
/// nonces differ even though the key is shared. Files are encrypted in parallel.
///
/// # Errors
/// Returns an error in these cases:
/// - the repository key is empty;
/// - no target file is found;
/// - key derivation fails;
/// - any file fails to encrypt (that error is wrapped with the file's path).
pub fn encrypt_repo(repo: &'static Repo, paths: Vec<PathBuf>) -> Result<()> {
    let key = repo.get_key();
    // A misconfigured (empty) key is a user error, not a program bug: report it instead of panicking.
    ensure!(!key.is_empty(), "Key must not be empty");
    let target_files = if paths.is_empty() {
        list_files(repo.conf.crypt_list.iter(), repo.path())
    } else {
        list_files(paths, repo.path())
    };
    ensure!(!target_files.is_empty(), "No file to encrypt");

    let mut salt = [0u8; SALT_LEN];
    rand::rng().fill_bytes(&mut salt);
    let derived_key = derive_key(key.as_bytes(), &salt)?;
    // Same compression setting for every file; computed once, outside the parallel loop.
    let zstd_level = repo.conf.use_zstd.then_some(repo.conf.zstd_level);

    target_files.par_iter().try_for_each(|f| -> Result<()> {
        encrypt_file(f, &derived_key, &salt, zstd_level)
            .with_context(|| format!("Failed to encrypt {}", f.display()))
    })?;
    Ok(())
}
pub fn decrypt_repo(repo: &'static Repo, paths: Vec<PathBuf>) -> Result<()> {
let key = repo.get_key();
assert!(!key.is_empty(), "Master key must not be empty");
let target_files = if paths.is_empty() {
list_files(repo.conf.crypt_list.iter(), repo.path())
} else {
list_files(paths, repo.path())
};
ensure!(!target_files.is_empty(), "No file to decrypt");
target_files
.par_iter()
.filter(|p| p.is_file())
.try_for_each(|f| -> Result<()> {
decrypt_file(f, key.as_bytes())
.with_context(|| format!("Failed to decrypt {}", f.display()))
})?;
Ok(())
}
#[cfg(test)]
mod tests {
    use std::io::{Cursor, Write};
    use tempfile::{NamedTempFile, TempPath};
    use super::*;

    /// Password every round-trip test encrypts with and decrypts with.
    const PASSWORD: &[u8] = b"super_secret_password";

    /// Derives a test key from `PASSWORD` under a fresh random salt.
    fn get_test_key_and_salt() -> ([u8; 32], [u8; SALT_LEN]) {
        let mut salt = [0u8; SALT_LEN];
        rand::rng().fill_bytes(&mut salt);
        let derived = derive_key(PASSWORD, &salt).unwrap();
        (*derived, salt)
    }

    /// Writes `content` to a new temporary file that is deleted when the handle drops.
    fn create_temp_file(content: &[u8]) -> TempPath {
        let mut file = NamedTempFile::new().unwrap();
        file.write_all(content).unwrap();
        file.flush().unwrap();
        file.into_temp_path()
    }

    /// Encrypts `plaintext` in a temp file, decrypts it again, and returns the final contents.
    fn roundtrip(plaintext: &[u8], zstd: Option<u8>) -> Vec<u8> {
        let path = create_temp_file(plaintext);
        let (key, salt) = get_test_key_and_salt();
        encrypt_file(&path, &key, &salt, zstd).unwrap();
        decrypt_file(&path, PASSWORD).unwrap();
        fs::read(&path).unwrap()
    }

    #[test]
    fn test_header_serialization() {
        let salt = [0xAB; SALT_LEN];
        let header = FileHeader::new(true, salt);
        let mut buf = Vec::new();
        header.write(&mut buf).unwrap();
        assert_eq!(buf.len(), HEADER_LEN);
        let mut cursor = Cursor::new(buf);
        let decoded = FileHeader::read(&mut cursor).unwrap();
        assert_eq!(decoded.version, VERSION);
        assert_eq!(decoded.flags, FLAG_COMPRESSED);
        assert_eq!(decoded.enc_algo, ENC_ALGO);
        assert_eq!(decoded.salt, salt);
        assert_eq!(decoded.nonce, header.nonce);
        assert!(decoded.is_compressed());
    }

    #[test]
    fn test_nonce_derivation() {
        let base_nonce = [0u8; NONCE_LEN];
        let nonce0 = derive_nonce(&base_nonce, 0);
        assert_eq!(nonce0.as_slice(), &[0u8; NONCE_LEN]);
        let nonce1 = derive_nonce(&base_nonce, 1);
        let mut expected1 = [0u8; NONCE_LEN];
        expected1[16] = 1;
        assert_eq!(nonce1.as_slice(), &expected1);
        let nonce256 = derive_nonce(&base_nonce, 256);
        let mut expected256 = [0u8; NONCE_LEN];
        expected256[17] = 1;
        assert_eq!(nonce256.as_slice(), &expected256);
    }

    #[test]
    fn test_encrypt_decrypt_basic_no_compression() {
        let plaintext = b"Hello, World! This is a test without compression.";
        let path = create_temp_file(plaintext);
        let (key, salt) = get_test_key_and_salt();
        encrypt_file(&path, &key, &salt, None).unwrap();
        let encrypted_content = fs::read(&path).unwrap();
        assert_ne!(encrypted_content, plaintext);
        assert_eq!(&encrypted_content[0..5], MAGIC);
        decrypt_file(&path, PASSWORD).unwrap();
        assert_eq!(fs::read(&path).unwrap(), plaintext);
    }

    #[test]
    fn test_encrypt_decrypt_with_compression() {
        let plaintext = b"A".repeat(10000);
        let path = create_temp_file(&plaintext);
        let (key, salt) = get_test_key_and_salt();
        encrypt_file(&path, &key, &salt, Some(3)).unwrap();
        let encrypted_meta = fs::metadata(&path).unwrap();
        assert!(encrypted_meta.len() < 5000);
        decrypt_file(&path, PASSWORD).unwrap();
        assert_eq!(fs::read(&path).unwrap(), plaintext);
    }

    #[test]
    fn test_chunked_encryption_large_file() {
        let plaintext: Vec<u8> = (0..=u8::MAX).cycle().take(100_000).collect();
        assert_eq!(roundtrip(&plaintext, None), plaintext);
    }

    #[test]
    fn test_empty_file_roundtrip() {
        assert!(roundtrip(b"", None).is_empty());
        assert!(roundtrip(b"", Some(3)).is_empty());
    }

    #[test]
    fn test_exact_chunk_multiple_roundtrip() {
        // Ends with an empty LAST chunk after two full MORE chunks.
        let plaintext = vec![0x5Au8; 2 * CHUNK_SIZE];
        assert_eq!(roundtrip(&plaintext, None), plaintext);
    }

    #[test]
    fn test_truncation_detected() {
        let plaintext = vec![7u8; CHUNK_SIZE];
        let path = create_temp_file(&plaintext);
        let (key, salt) = get_test_key_and_salt();
        encrypt_file(&path, &key, &salt, None).unwrap();
        // Layout: header, one full MORE chunk (+16-byte tag), one empty LAST chunk (tag only).
        let len = fs::metadata(&path).unwrap().len();
        let expected_len = u64::try_from(HEADER_LEN + CHUNK_SIZE + 16 + 16).unwrap();
        assert_eq!(len, expected_len);
        // Drop the LAST chunk so the stream ends cleanly on a chunk boundary.
        fs::OpenOptions::new()
            .write(true)
            .open(&path)
            .unwrap()
            .set_len(len - 16)
            .unwrap();
        let err = decrypt_file(&path, PASSWORD).unwrap_err();
        assert!(err.to_string().contains("truncation"));
    }

    #[test]
    fn test_encrypt_skips_already_encrypted() {
        let path = create_temp_file(b"only encrypt me once");
        let (key, salt) = get_test_key_and_salt();
        encrypt_file(&path, &key, &salt, None).unwrap();
        let once = fs::read(&path).unwrap();
        encrypt_file(&path, &key, &salt, None).unwrap();
        assert_eq!(fs::read(&path).unwrap(), once);
        decrypt_file(&path, PASSWORD).unwrap();
        assert_eq!(fs::read(&path).unwrap(), b"only encrypt me once");
    }

    #[test]
    fn test_decrypt_skips_plaintext() {
        let path = create_temp_file(b"short plaintext");
        decrypt_file(&path, PASSWORD).unwrap();
        assert_eq!(fs::read(&path).unwrap(), b"short plaintext");
    }

    #[test]
    fn test_key_cache_shared_across_files() {
        let (key, salt) = get_test_key_and_salt();
        let first = create_temp_file(b"first file");
        let second = create_temp_file(b"second file");
        encrypt_file(&first, &key, &salt, None).unwrap();
        encrypt_file(&second, &key, &salt, Some(3)).unwrap();
        let key_cache = DashMap::new();
        decrypt_file_with_cache(&first, &key_cache, PASSWORD).unwrap();
        decrypt_file_with_cache(&second, &key_cache, PASSWORD).unwrap();
        assert_eq!(key_cache.len(), 1);
        assert_eq!(fs::read(&first).unwrap(), b"first file");
        assert_eq!(fs::read(&second).unwrap(), b"second file");
    }

    #[test]
    fn test_tamper_resistance() {
        let plaintext = b"Sensitive data that should not be tampered with.";
        let path = create_temp_file(plaintext);
        let (key, salt) = get_test_key_and_salt();
        encrypt_file(&path, &key, &salt, None).unwrap();
        let mut encrypted_content = fs::read(&path).unwrap();
        encrypted_content[HEADER_LEN + 5] ^= 0xFF;
        fs::write(&path, &encrypted_content).unwrap();
        let result = decrypt_file(&path, PASSWORD);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Decryption failed")
        );
    }

    #[cfg(unix)]
    #[test]
    fn test_metadata_preservation() {
        use std::os::unix::fs::PermissionsExt;
        let file = create_temp_file(b"Executable script content");
        let path: &Path = &file;
        fs::set_permissions(path, fs::Permissions::from_mode(0o755)).unwrap();
        let (key, salt) = get_test_key_and_salt();
        encrypt_file(path, &key, &salt, None).unwrap();
        let encrypted_perms = fs::metadata(path).unwrap().permissions();
        assert_eq!(encrypted_perms.mode() & 0o777, 0o755);
        let key_cache = DashMap::new();
        decrypt_file_with_cache(path, &key_cache, PASSWORD).unwrap();
        let decrypted_perms = fs::metadata(path).unwrap().permissions();
        assert_eq!(decrypted_perms.mode() & 0o777, 0o755);
    }
}