use bytes::Bytes;
use fs2::FileExt;
use sha2::{Digest, Sha256};
use std::{
collections::HashMap,
io,
path::{Path, PathBuf},
sync::{Arc, LazyLock},
};
use tokio::{
fs,
sync::{Mutex, Notify, OnceCell},
};
/// Root directory for cached downloads: the platform cache directory reported
/// by `dirs::cache_dir()`, or the system temp directory when that is
/// unavailable, with a `tyml` subdirectory appended.
static CACHE_DIR: LazyLock<PathBuf> = LazyLock::new(|| {
    dirs::cache_dir()
        .unwrap_or_else(std::env::temp_dir)
        .join("tyml")
});

/// Set once [`CACHE_DIR`] has been created successfully, so that later calls to
/// `ensure_cache_ready` return without touching the filesystem.
static INIT_DONE: OnceCell<()> = OnceCell::const_new();

/// Downloads currently in progress, keyed by URL, each with a shared `Notify`.
/// Entries are inserted by `get_cached_file` and removed by `finish_waiters`.
static INFLIGHT: LazyLock<Mutex<HashMap<String, Arc<Notify>>>> =
    LazyLock::new(|| Mutex::new(HashMap::new()));
pub async fn get_cached_file(url: &str) -> io::Result<PathBuf> {
ensure_cache_ready().await?;
let cache_path = CACHE_DIR.join(format!("{}.tyml", hash_url(url)));
let notify = {
let mut map = INFLIGHT.lock().await;
map.entry(url.to_string())
.or_insert_with(|| Arc::new(Notify::new()))
.clone()
};
let lock_path = cache_path.with_extension("tyml.lock");
let _file_lock = tokio::task::spawn_blocking(move || acquire_file_lock(&lock_path))
.await
.expect("join failure")?;
match try_download(url).await {
Ok(bytes) => {
write_atomically(&cache_path, &bytes).await?;
}
Err(net_err) => {
if cache_path.exists() {
finish_waiters(url, ¬ify).await;
return Ok(cache_path);
} else {
finish_waiters(url, ¬ify).await;
return Err(net_err);
}
}
}
finish_waiters(url, ¬ify).await;
Ok(cache_path)
}
/// Issues a GET request for `url` and returns the full response body.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// - the request fails (the reqwest error is wrapped with `io::Error::other`);
/// - the server answers with a non-success status (`"HTTP error: <status>"`);
/// - reading the body fails.
async fn try_download(url: &str) -> io::Result<Bytes> {
    let client = reqwest::Client::new();
    let response = client.get(url).send().await.map_err(io::Error::other)?;

    let status = response.status();
    if status.is_success() {
        response.bytes().await.map_err(io::Error::other)
    } else {
        Err(io::Error::other(format!("HTTP error: {}", status)))
    }
}
/// Creates [`CACHE_DIR`], including any missing parents, the first time it is
/// needed.
///
/// Success is memoized in [`INIT_DONE`], so later calls return immediately.
/// A failed attempt leaves the cell uninitialized, and the next call retries.
///
/// # Errors
///
/// Propagates the I/O error from creating the directory. This includes the
/// case where the path exists but is not a directory.
async fn ensure_cache_ready() -> io::Result<()> {
    // `create_dir_all` already returns `Ok` when the directory exists, so a
    // separate `exists()` check would be redundant, blocking and racy.
    INIT_DONE
        .get_or_try_init(|| fs::create_dir_all(&*CACHE_DIR))
        .await?;
    Ok(())
}
/// Replaces the contents of `path` with `bytes` by writing a temporary sibling
/// file and renaming it over `path`.
///
/// The temporary file is `path` with its extension replaced by `tyml.tmp`.
/// Its data is flushed to disk (`sync_all`) before the rename. Otherwise a
/// crash could persist the rename but not the data, leaving an empty or
/// truncated cache file.
///
/// The temporary name is fixed per destination, so callers must serialize
/// writes to the same `path`. In this module, that is done with the lock file.
///
/// # Errors
///
/// Returns the first I/O error from writing, syncing or renaming. On error,
/// the temporary file is removed on a best-effort basis.
async fn write_atomically(path: &Path, bytes: &Bytes) -> io::Result<()> {
    let tmp = path.with_extension("tyml.tmp");
    let result = write_synced_then_rename(&tmp, path, bytes).await;
    if result.is_err() {
        // Best-effort cleanup: the original error is the one worth reporting.
        let _ = fs::remove_file(&tmp).await;
    }
    result
}

/// Writes `bytes` to `tmp`, forces the data to disk, then renames `tmp` to `dest`.
async fn write_synced_then_rename(tmp: &Path, dest: &Path, bytes: &[u8]) -> io::Result<()> {
    fs::write(tmp, bytes).await?;
    // Reopen with write access because Windows requires it to flush a file.
    // The handle is dropped before the rename.
    fs::OpenOptions::new()
        .write(true)
        .open(tmp)
        .await?
        .sync_all()
        .await?;
    fs::rename(tmp, dest).await
}
/// Wakes every task currently waiting on `notify`, then removes `url`'s entry
/// from [`INFLIGHT`].
///
/// NOTE(review): the entry is removed unconditionally, even if it now holds a
/// different `Notify` than the one passed in. That happens when a later call
/// for the same URL registered a fresh one after an earlier call removed the
/// shared entry.
async fn finish_waiters(url: &str, notify: &Notify) {
    notify.notify_waiters();

    let mut inflight = INFLIGHT.lock().await;
    inflight.remove(url);
}
/// Opens the lock file at `lock_path`, creating it and any missing parent
/// directories if needed. Then blocks the current thread until an exclusive
/// lock on it is obtained (via `fs2::FileExt::lock_exclusive`).
///
/// The lock is tied to the returned `File`. Keep it alive for as long as the
/// lock must be held.
///
/// # Errors
///
/// Returns any I/O error from creating the parent directory, opening the
/// file, or acquiring the lock.
fn acquire_file_lock(lock_path: &Path) -> io::Result<std::fs::File> {
    if let Some(parent) = lock_path.parent() {
        std::fs::create_dir_all(parent)?;
    }
    let file = std::fs::OpenOptions::new()
        .create(true)
        .write(true)
        // The lock file's contents are irrelevant. State explicitly that an
        // existing file is not truncated (clippy::suspicious_open_options).
        .truncate(false)
        .open(lock_path)?;
    file.lock_exclusive()?;
    Ok(file)
}
/// Derives a stable cache key for `u`: its SHA-256 digest rendered as
/// lowercase hexadecimal (64 characters).
fn hash_url(u: &str) -> String {
    let digest = Sha256::digest(u.as_bytes());
    format!("{digest:x}")
}