use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use tokio::sync::Mutex;
use crate::config;
use crate::crypto;
use crate::crypto::wal;
#[derive(Debug, Clone, Serialize, Deserialize)]
struct MtimeEntry {
db_mt: u64,
wal_mt: u64,
path: String,
}
#[derive(Debug, Clone)]
struct CacheEntry {
db_mtime: u64,
wal_mtime: u64,
decrypted_path: PathBuf,
}
pub struct DbCache {
db_dir: PathBuf,
cache_dir: PathBuf,
all_keys: HashMap<String, String>, inner: Arc<Mutex<HashMap<String, CacheEntry>>>,
}
impl DbCache {
pub async fn new(
db_dir: PathBuf,
all_keys: HashMap<String, String>,
) -> Result<Self> {
let cache_dir = config::cache_dir();
tokio::fs::create_dir_all(&cache_dir).await?;
let inner: HashMap<String, CacheEntry> = HashMap::new();
let cache = DbCache {
db_dir,
cache_dir,
all_keys,
inner: Arc::new(Mutex::new(inner)),
};
cache.load_persistent().await;
Ok(cache)
}
fn cache_file_path(&self, rel_key: &str) -> PathBuf {
let hash = format!("{:x}", md5::compute(rel_key.as_bytes()));
let short = &hash[..12];
self.cache_dir.join(format!("{}.db", short))
}
async fn load_persistent(&self) {
let mtime_file = config::mtime_file();
let content = match tokio::fs::read_to_string(&mtime_file).await {
Ok(c) => c,
Err(_) => return,
};
let saved: HashMap<String, MtimeEntry> = match serde_json::from_str(&content) {
Ok(v) => v,
Err(_) => return,
};
let mut inner = self.inner.lock().await;
let mut reused = 0usize;
for (rel_key, entry) in &saved {
let dec_path = PathBuf::from(&entry.path);
if !dec_path.exists() {
continue;
}
let db_path = self.db_dir.join(rel_key.replace('\\', std::path::MAIN_SEPARATOR_STR).replace('/', std::path::MAIN_SEPARATOR_STR));
let wal_path_str = format!("{}-wal", db_path.display());
let wal_path = Path::new(&wal_path_str);
let db_mt = mtime_nanos(&db_path);
let wal_mt = if wal_path.exists() { mtime_nanos(wal_path) } else { 0 };
if db_mt == entry.db_mt && wal_mt == entry.wal_mt {
inner.insert(rel_key.clone(), CacheEntry {
db_mtime: db_mt,
wal_mtime: wal_mt,
decrypted_path: dec_path,
});
reused += 1;
}
}
if reused > 0 {
eprintln!("[cache] 复用 {} 个已解密 DB", reused);
}
}
async fn save_persistent(&self) {
let mtime_file = config::mtime_file();
let inner = self.inner.lock().await;
let data: HashMap<String, MtimeEntry> = inner.iter().map(|(k, v)| {
(k.clone(), MtimeEntry {
db_mt: v.db_mtime,
wal_mt: v.wal_mtime,
path: v.decrypted_path.to_string_lossy().into_owned(),
})
}).collect();
drop(inner);
if let Ok(json) = serde_json::to_string_pretty(&data) {
let _ = tokio::fs::write(&mtime_file, json).await;
}
}
pub async fn get(&self, rel_key: &str) -> Result<Option<PathBuf>> {
let enc_key_hex = match self.all_keys.get(rel_key) {
Some(k) => k.clone(),
None => return Ok(None),
};
let db_path = self.db_dir.join(
rel_key.replace('\\', std::path::MAIN_SEPARATOR_STR)
.replace('/', std::path::MAIN_SEPARATOR_STR)
);
if !db_path.exists() {
return Ok(None);
}
let wal_path_str = format!("{}-wal", db_path.display());
let wal_path = Path::new(&wal_path_str).to_path_buf();
let db_mt = mtime_nanos(&db_path);
let wal_mt = if wal_path.exists() { mtime_nanos(&wal_path) } else { 0 };
{
let inner = self.inner.lock().await;
if let Some(entry) = inner.get(rel_key) {
if entry.db_mtime == db_mt
&& entry.wal_mtime == wal_mt
&& entry.decrypted_path.exists()
{
return Ok(Some(entry.decrypted_path.clone()));
}
}
}
let out_path = self.cache_file_path(rel_key);
let enc_key_bytes = hex_to_32bytes(&enc_key_hex)
.with_context(|| format!("密钥格式错误: {}", rel_key))?;
let t0 = std::time::Instant::now();
let db_path2 = db_path.clone();
let out_path2 = out_path.clone();
let key_copy = enc_key_bytes;
tokio::task::spawn_blocking(move || {
crypto::full_decrypt(&db_path2, &out_path2, &key_copy)
}).await??;
if wal_path.exists() {
let out_path3 = out_path.clone();
let wal_path3 = wal_path.clone();
let key_copy2 = enc_key_bytes;
tokio::task::spawn_blocking(move || {
wal::apply_wal(&wal_path3, &out_path3, &key_copy2)
}).await??;
}
let elapsed_ms = t0.elapsed().as_millis();
eprintln!("[cache] 解密 {} ({}ms)", rel_key, elapsed_ms);
{
let mut inner = self.inner.lock().await;
inner.insert(rel_key.to_string(), CacheEntry {
db_mtime: db_mt,
wal_mtime: wal_mt,
decrypted_path: out_path.clone(),
});
}
self.save_persistent().await;
Ok(Some(out_path))
}
}
fn mtime_nanos(path: &Path) -> u64 {
std::fs::metadata(path)
.and_then(|m| m.modified())
.map(|t| t.duration_since(std::time::UNIX_EPOCH).unwrap_or_default().as_nanos() as u64)
.unwrap_or(0)
}
fn hex_to_32bytes(s: &str) -> Result<[u8; 32]> {
if s.len() != 64 {
anyhow::bail!("密钥 hex 长度应为 64,实际为 {}", s.len());
}
let mut out = [0u8; 32];
for i in 0..32 {
out[i] = u8::from_str_radix(&s[i * 2..i * 2 + 2], 16)
.with_context(|| format!("非法 hex 字符 at {}", i * 2))?;
}
Ok(out)
}