use std::{
collections::HashMap,
path::{Path, PathBuf},
sync::mpsc,
};
use log::{debug, warn};
use memmap2::Mmap;
use rkyv::rancor::Error as RkyvError;
use crate::{
crypt::{FILE_ID_LEN, SALT_LEN},
utils::atomic_write,
};
/// File name of the salt cache; `cache_path` places it inside the repository's `.git` directory.
const CACHE_FILENAME: &str = "git-simple-encrypt-salt-cache";
/// Value stored per key in the on-disk salt cache (an rkyv-archived
/// `HashMap<Vec<u8>, CachedEntry>`).
///
/// NOTE: the field set, types, and order define the archived layout. Changing
/// them makes existing cache files incompatible; such files would then be
/// reported as corrupted on load and ignored.
#[derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize, Clone, Debug, PartialEq, Eq)]
pub struct CachedEntry {
    /// Salt bytes recorded for the key (presumably the salt used when the file
    /// was encrypted — see `crate::crypt::SALT_LEN`; confirm at call sites).
    pub salt: [u8; SALT_LEN],
    /// File identifier bytes recorded for the key (sized by
    /// `crate::crypt::FILE_ID_LEN`; semantics defined by the crypt module).
    pub file_id: [u8; FILE_ID_LEN],
}
/// Builds the location of the salt cache for the repository rooted at
/// `repo_path`, i.e. `<repo_path>/.git/<CACHE_FILENAME>`.
fn cache_path(repo_path: &Path) -> PathBuf {
    let mut location = repo_path.to_path_buf();
    location.push(".git");
    location.push(CACHE_FILENAME);
    location
}
/// Read-only, memory-mapped view of the on-disk salt cache.
///
/// The mapped bytes are validated once in [`SaltCacheReader::load`]; lookups in
/// [`SaltCacheReader::get`] then access the archive without re-validating it.
pub struct SaltCacheReader {
    /// `Some` only when the mapped bytes passed `rkyv::access` validation in `load`.
    mmap: Option<Mmap>,
}

impl SaltCacheReader {
    /// Loads the salt cache for the repository at `repo_path`.
    ///
    /// Never fails. A missing cache is logged at debug level. An unopenable,
    /// unmappable, or corrupted cache is logged as a warning. In every failure
    /// case the reader is empty and [`get`](Self::get) always returns `None`.
    pub fn load(repo_path: &Path) -> Self {
        let path = cache_path(repo_path);
        // Open directly instead of checking `path.exists()` first: `exists()`
        // hides errors such as permission denied behind a plain `false`, and
        // the separate check races with the open.
        let mmap = match std::fs::File::open(&path) {
            Ok(file) => Self::map_validated(&path, &file),
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
                debug!("Salt cache not found at {}", path.display());
                None
            }
            Err(e) => {
                warn!("Failed to open salt cache at {}: {e}", path.display());
                None
            }
        };
        Self { mmap }
    }

    /// Memory-maps `file` and returns the mapping only if it validates as an
    /// archived cache map. Failures are logged and yield `None`.
    fn map_validated(path: &Path, file: &std::fs::File) -> Option<Mmap> {
        // SAFETY: the mapping is only read through this handle. Soundness
        // additionally requires that no other process truncates or rewrites the
        // file in place while it is mapped. This relies on writers replacing the
        // file via `atomic_write`, which presumably renames a temp file into
        // place (NOTE(review): confirm in `utils::atomic_write`).
        let mmap = match unsafe { Mmap::map(file) } {
            Ok(mmap) => mmap,
            Err(e) => {
                warn!("Failed to mmap salt cache at {}: {e}", path.display());
                return None;
            }
        };
        match rkyv::access::<rkyv::Archived<HashMap<Vec<u8>, CachedEntry>>, RkyvError>(&mmap) {
            Ok(_) => {
                debug!("Loaded salt cache from {}", path.display());
                Some(mmap)
            }
            Err(e) => {
                warn!("Corrupted salt cache at {}: {e}", path.display());
                None
            }
        }
    }

    /// Looks up the cached entry for `key`.
    ///
    /// Keys are raw byte strings; the tests use repository-relative paths such
    /// as `sub/file2.txt`. Returns `None` if the cache failed to load or the
    /// key is absent.
    pub fn get(&self, key: &[u8]) -> Option<CachedEntry> {
        let mmap = self.mmap.as_ref()?;
        // SAFETY: `self.mmap` is only ever set by `load` after `rkyv::access`
        // validated exactly these bytes as this archived type, and the mapping
        // is never mutated through this handle.
        let archived = unsafe {
            rkyv::access_unchecked::<rkyv::Archived<HashMap<Vec<u8>, CachedEntry>>>(mmap.as_ref())
        };
        archived.get(key).map(|entry| CachedEntry {
            salt: entry.salt,
            file_id: entry.file_id,
        })
    }
}
/// Producer handle that queues salt-cache entries for the paired
/// [`SaltCacheSaver`] over an `mpsc` channel.
pub struct SaltCacheSender {
    tx: mpsc::Sender<(Vec<u8>, CachedEntry)>,
}

impl SaltCacheSender {
    /// Queues `entry` under an owned copy of `key`.
    ///
    /// Best-effort: a send error only occurs when the receiving
    /// [`SaltCacheSaver`] has already been dropped. In that case nothing could
    /// persist the entry, so it is deliberately discarded.
    pub fn insert(&self, key: &[u8], entry: CachedEntry) {
        let owned_key: Vec<u8> = key.to_owned();
        let message = (owned_key, entry);
        let _ = self.tx.send(message);
    }
}
pub struct SaltCacheSaver {
rx: mpsc::Receiver<(Vec<u8>, CachedEntry)>,
repo_path: PathBuf,
}
impl SaltCacheSaver {
pub fn save(self) {
let Self { rx, repo_path } = self;
let mut entries: HashMap<Vec<u8>, CachedEntry> = rx.into_iter().collect();
if entries.is_empty() {
debug!("No cache entries to save");
return;
}
let path = cache_path(&repo_path);
if path.exists()
&& let Ok(existing_bytes) = std::fs::read(&path)
&& let Ok(existing) =
rkyv::from_bytes::<HashMap<Vec<u8>, CachedEntry>, RkyvError>(&existing_bytes)
{
for (k, v) in existing {
entries.entry(k).or_insert(v);
}
}
match rkyv::to_bytes::<RkyvError>(&entries) {
Ok(bytes) => {
if let Err(e) = atomic_write(&path, bytes.as_slice()) {
warn!("Failed to save salt cache to {}: {e}", path.display());
} else {
debug!(
"Saved salt cache with {} entries to {}",
entries.len(),
path.display()
);
}
}
Err(e) => {
warn!("Failed to serialize salt cache: {e}");
}
}
}
}
/// Creates a linked ([`SaltCacheSender`], [`SaltCacheSaver`]) pair for the
/// repository at `repo_path`.
///
/// Entries passed to [`SaltCacheSender::insert`] travel over an unbounded
/// `std::sync::mpsc` channel. [`SaltCacheSaver::save`] gathers them once the
/// sender has been dropped and writes them to the repository's cache file.
pub fn create_writer(repo_path: &Path) -> (SaltCacheSender, SaltCacheSaver) {
    let (sending_half, receiving_half) = mpsc::channel();
    let saver = SaltCacheSaver {
        rx: receiving_half,
        repo_path: repo_path.to_owned(),
    };
    let sender = SaltCacheSender { tx: sending_half };
    (sender, saver)
}
#[cfg(test)]
mod tests {
    use tempfile::TempDir;
    use super::*;

    /// Builds an entry whose salt and file id are filled with the given bytes.
    fn make_entry(salt_byte: u8, file_id_byte: u8) -> CachedEntry {
        CachedEntry {
            salt: [salt_byte; SALT_LEN],
            file_id: [file_id_byte; FILE_ID_LEN],
        }
    }

    /// A directory with no cache file yields an empty reader.
    #[test]
    fn test_reader_get_from_wrong_path() {
        let dir = TempDir::new().unwrap();
        let reader = SaltCacheReader::load(dir.path());
        assert_eq!(reader.get(b"test.txt"), None);
    }

    #[test]
    fn test_roundtrip_via_sender_and_reader() {
        let dir = TempDir::new().unwrap();
        let repo = dir.path();
        std::fs::create_dir_all(repo.join(".git")).unwrap();
        let entry1 = make_entry(0x11, 0x22);
        let entry2 = make_entry(0x33, 0x44);
        {
            let (sender, saver) = create_writer(repo);
            sender.insert(b"file1.txt", entry1.clone());
            sender.insert(b"sub/file2.txt", entry2.clone());
            drop(sender);
            saver.save();
        }
        let reader = SaltCacheReader::load(repo);
        assert_eq!(reader.get(b"file1.txt"), Some(entry1));
        assert_eq!(reader.get(b"sub/file2.txt"), Some(entry2));
        assert_eq!(reader.get(b"nonexistent.txt"), None);
    }

    #[test]
    fn test_load_corrupted_file() {
        let dir = TempDir::new().unwrap();
        let repo = dir.path();
        std::fs::create_dir_all(repo.join(".git")).unwrap();
        let path = cache_path(repo);
        std::fs::write(&path, b"not valid rkyv data").unwrap();
        let reader = SaltCacheReader::load(repo);
        assert_eq!(reader.get(b"test.txt"), None);
    }

    /// Within one session, the last entry sent for a key wins.
    #[test]
    fn test_overwrite_entry() {
        let dir = TempDir::new().unwrap();
        let repo = dir.path();
        std::fs::create_dir_all(repo.join(".git")).unwrap();
        let entry1 = make_entry(0x11, 0x22);
        let entry2 = make_entry(0x33, 0x44);
        {
            let (sender, saver) = create_writer(repo);
            sender.insert(b"test.txt", entry1);
            sender.insert(b"test.txt", entry2.clone());
            drop(sender);
            saver.save();
        }
        let reader = SaltCacheReader::load(repo);
        assert_eq!(reader.get(b"test.txt"), Some(entry2));
    }

    #[test]
    fn test_relative_path_key_persistence() {
        let dir = TempDir::new().unwrap();
        let repo = dir.path();
        std::fs::create_dir_all(repo.join(".git")).unwrap();
        let entry = make_entry(0x55, 0x66);
        {
            let (sender, saver) = create_writer(repo);
            sender.insert(b"subdir/file.txt", entry.clone());
            drop(sender);
            saver.save();
        }
        let reader = SaltCacheReader::load(repo);
        assert_eq!(reader.get(b"subdir/file.txt"), Some(entry));
    }

    /// Entries from an earlier save survive a later save of different keys.
    #[test]
    fn test_merge_with_existing() {
        let dir = TempDir::new().unwrap();
        let repo = dir.path();
        std::fs::create_dir_all(repo.join(".git")).unwrap();
        let entry_a = make_entry(0xAA, 0xBB);
        let entry_b = make_entry(0xCC, 0xDD);
        {
            let (sender, saver) = create_writer(repo);
            sender.insert(b"existing.txt", entry_a.clone());
            drop(sender);
            saver.save();
        }
        {
            let (sender, saver) = create_writer(repo);
            sender.insert(b"new.txt", entry_b.clone());
            drop(sender);
            saver.save();
        }
        let reader = SaltCacheReader::load(repo);
        assert_eq!(reader.get(b"existing.txt"), Some(entry_a));
        assert_eq!(reader.get(b"new.txt"), Some(entry_b));
    }

    /// When a key already exists on disk, the newly saved entry takes precedence.
    #[test]
    fn test_new_entry_overrides_existing_on_disk() {
        let dir = TempDir::new().unwrap();
        let repo = dir.path();
        std::fs::create_dir_all(repo.join(".git")).unwrap();
        let old_entry = make_entry(0x01, 0x02);
        let new_entry = make_entry(0x03, 0x04);
        for entry in [&old_entry, &new_entry] {
            let (sender, saver) = create_writer(repo);
            sender.insert(b"same.txt", entry.clone());
            drop(sender);
            saver.save();
        }
        let reader = SaltCacheReader::load(repo);
        assert_eq!(reader.get(b"same.txt"), Some(new_entry));
    }

    /// Saving with no queued entries must not create a cache file.
    #[test]
    fn test_save_without_entries_writes_nothing() {
        let dir = TempDir::new().unwrap();
        let repo = dir.path();
        std::fs::create_dir_all(repo.join(".git")).unwrap();
        let (sender, saver) = create_writer(repo);
        drop(sender);
        saver.save();
        assert!(!cache_path(repo).exists());
    }

    /// A corrupted cache file is replaced by a valid one on the next save.
    #[test]
    fn test_save_replaces_corrupted_cache() {
        let dir = TempDir::new().unwrap();
        let repo = dir.path();
        std::fs::create_dir_all(repo.join(".git")).unwrap();
        std::fs::write(cache_path(repo), b"not valid rkyv data").unwrap();
        let entry = make_entry(0x77, 0x88);
        {
            let (sender, saver) = create_writer(repo);
            sender.insert(b"fresh.txt", entry.clone());
            drop(sender);
            saver.save();
        }
        let reader = SaltCacheReader::load(repo);
        assert_eq!(reader.get(b"fresh.txt"), Some(entry));
    }
}