use std::{
collections::{HashMap, hash_map},
fmt,
io::Write,
num::NonZeroUsize,
path::{Path, PathBuf},
sync::{Arc, LazyLock},
time::{Duration, Instant},
};
use anyhow::Context as _;
use cache::Cache;
use const_format::formatcp;
use futures_util::StreamExt as _;
use parking_lot::{Mutex, RwLock};
use reqwest::{IntoUrl, Url, header::HeaderMap};
use crate::{
cl::SourceName,
cover,
source::{self, RateLimit},
};
mod cache;
/// `User-Agent` header value sent with every request: `<crate name>/<crate version> (<project URL>)`.
pub(crate) const USER_AGENT: &str = formatcp!(
    "{name}/{version} (https://github.com/desbma/sacad)",
    name = env!("CARGO_PKG_NAME"),
    version = env!("CARGO_PKG_VERSION"),
);
/// Cache for source API responses, using the `Lz4Compressor` from the `cache` module.
type ApiCache = Cache<cache::Lz4Compressor>;
/// Cache for downloaded thumbnails, using the `NullCompressor` from the `cache` module.
type ThumbnailCache = Cache<cache::NullCompressor>;
/// HTTP client used by a single cover source.
///
/// API responses and thumbnails go through separate caches (created under the cache directory
/// given to `SourceHttpClient::new`), and network requests can be throttled by a rate limit.
pub(crate) struct SourceHttpClient {
    /// Underlying client, configured with user agent, timeout and default headers.
    client: reqwest::Client,
    /// API response cache, shared with every other client of the same source.
    api_cache: Arc<ApiCache>,
    /// Thumbnail cache, shared with every other client of the same source.
    thumbnail_cache: Arc<ThumbnailCache>,
    /// Request throttling state.
    rate_limit: RateLimitState,
}
/// Process-wide registry of `(api cache, thumbnail cache)` pairs, one per source.
///
/// Lets several clients of the same source share the same cache instances.
#[expect(clippy::type_complexity)]
static SOURCE_CACHES: LazyLock<RwLock<HashMap<SourceName, (Arc<ApiCache>, Arc<ThumbnailCache>)>>> =
    LazyLock::new(|| RwLock::new(HashMap::new()));
/// Return the platform-specific cache directory for this program.
///
/// # Errors
///
/// Fails if the project directories cannot be determined (e.g. no home directory).
pub(crate) fn default_cache_dir() -> anyhow::Result<PathBuf> {
    directories::ProjectDirs::from("", "", env!("CARGO_PKG_NAME"))
        .map(|project_dirs| project_dirs.cache_dir().to_path_buf())
        .context("Unable to compute cache directory")
}
impl SourceHttpClient {
pub(crate) fn new(
source_name: SourceName,
ua: &str,
timeout: Duration,
headers: HeaderMap,
rate_limit: Option<&RateLimit>,
cache_dir: &Path,
) -> anyhow::Result<Self> {
let client = reqwest::Client::builder()
.user_agent(ua)
.timeout(timeout)
.default_headers(headers)
.build()
.context("Failed to create HTTP client")?;
let (api_cache, thumbnail_cache) = if let Some(caches) =
SOURCE_CACHES.read().get(&source_name).cloned()
{
caches
} else {
match SOURCE_CACHES.write().entry(source_name) {
hash_map::Entry::Occupied(entry) => {
let (api_cache, thumbnail_cache) = entry.get();
(Arc::clone(api_cache), Arc::clone(thumbnail_cache))
}
hash_map::Entry::Vacant(entry) => {
let api_cache = Arc::new(
Cache::new(source_name, source::RESPONSE_MAX_AGE, cache_dir).with_context(
|| format!("Failed to initialize {source_name} api cache"),
)?,
);
let thumbnail_cache = Arc::new(
ThumbnailCache::new(
format!("{source_name}_thumbs"),
cover::THUMBNAIL_MAX_AGE,
cache_dir,
)
.with_context(|| {
format!("Failed to initialize {source_name} thumbnail cache")
})?,
);
entry.insert((Arc::clone(&api_cache), Arc::clone(&thumbnail_cache)));
(api_cache, thumbnail_cache)
}
}
};
let rate_limit_state = match rate_limit {
Some(RateLimit { time, max_count }) => {
RateLimitState::Window(Mutex::new(RateLimitWindow {
start: Instant::now(),
length: *time,
count: 0,
#[expect(clippy::unwrap_used)]
limit: NonZeroUsize::new(*max_count).unwrap(),
}))
}
None => RateLimitState::None,
};
Ok(Self {
client,
api_cache,
thumbnail_cache,
rate_limit: rate_limit_state,
})
}
async fn wait(&self) {
while let Some(time_to_sleep) = self.rate_limit.wait_for() {
log::debug!(
"Waiting for {:.3}s because of rate limit",
time_to_sleep.as_secs_f64()
);
tokio::time::sleep(time_to_sleep).await;
}
}
pub(crate) async fn head(&self, url: Url) -> anyhow::Result<bool> {
log::trace!("HEAD {url}");
Ok(self
.client
.head(url.clone())
.send()
.await
.with_context(|| format!("Internal HTTP error for URL {url:?}"))?
.status()
.is_success())
}
pub(crate) async fn download_cover<U, W>(&self, url: U, mut writer: W) -> anyhow::Result<()>
where
U: IntoUrl,
W: Write,
{
self.wait().await;
log::debug!("Downloading {}...", url.as_str());
let response = self
.client
.get(url)
.timeout(Duration::from_secs(60))
.send()
.await?;
anyhow::ensure!(
response.status().is_success(),
"Request failed with status: {}",
response.status()
);
let mut stream = response.bytes_stream();
while let Some(chunk) = stream.next().await {
let chunk = chunk.context("Failed to download chunk")?;
writer
.write_all(&chunk)
.context("Failed to write chunk to file")?;
}
Ok(())
}
pub(crate) async fn download_thumbnail<U>(&self, url: U) -> anyhow::Result<Vec<u8>>
where
U: AsRef<str> + IntoUrl + Clone,
{
Ok(self
.thumbnail_cache
.get_or_set(url.clone(), async {
let mut writer = Vec::new();
self.download_cover(url, &mut writer).await?;
Ok(writer)
})
.await?)
}
async fn get_api(&self, url: Url) -> anyhow::Result<Vec<u8>> {
log::trace!("GET {url}");
let cache_key = url.as_str().to_owned();
if let Some(cache_hit) = self
.api_cache
.get(&cache_key)
.with_context(|| format!("Cache retrieval failed for key {cache_key:?}"))?
{
log::trace!("Cache hit for key {cache_key:?}");
Ok(cache_hit)
} else {
self.wait().await;
let response = self
.client
.get(url)
.send()
.await
.with_context(|| format!("Internal HTTP error for URL {cache_key:?}"))?
.error_for_status()
.with_context(|| format!("HTTP error for URL {cache_key:?}"))?;
let data = response.bytes().await?;
self.api_cache.set(&cache_key, &data)?;
Ok(data.into())
}
}
pub(crate) async fn get_json<R>(&self, url: Url) -> anyhow::Result<R>
where
R: serde::de::DeserializeOwned,
{
let data = self.get_api(url).await?;
log::trace!("{}", String::from_utf8_lossy(&data));
let r: R = serde_json::from_slice(&data)?;
Ok(r)
}
pub(crate) async fn get_xml<R>(&self, url: Url) -> anyhow::Result<R>
where
R: fmt::Debug + serde::de::DeserializeOwned,
{
let data = self.get_api(url).await?;
let data_s = str::from_utf8(&data).context("Failed to decode string")?;
log::trace!("{data_s}");
let r: R = quick_xml::de::from_str(data_s)?;
Ok(r)
}
}
/// Request throttling state of a `SourceHttpClient`.
enum RateLimitState {
    /// No rate limit: requests are never delayed.
    None,
    /// At most `limit` requests per fixed-length time window, guarded by a mutex.
    Window(Mutex<RateLimitWindow>),
}
/// Fixed time window request counter used for rate limiting.
struct RateLimitWindow {
    /// When the current window started.
    start: Instant,
    /// Length of each window.
    length: Duration,
    /// Number of requests already let through in the current window.
    count: usize,
    /// Maximum number of requests let through per window.
    limit: NonZeroUsize,
}
impl RateLimitState {
    /// Try to reserve a request slot in the current rate limit window.
    ///
    /// Returns `None` if the request may proceed now (the slot is then counted), or
    /// `Some(delay)` with the strictly positive time left until the current window ends, after
    /// which the caller should try again.
    fn wait_for(&self) -> Option<Duration> {
        match self {
            RateLimitState::None => None,
            RateLimitState::Window(state) => {
                let mut window = state.lock();
                let now = Instant::now();
                let elapsed = now.saturating_duration_since(window.start);
                if elapsed >= window.length {
                    // Windows are half-open `[start, start + length)`: once the end is reached,
                    // open a new window that already counts this request. Using `>=` (not `>`)
                    // ensures the delay returned below is never zero.
                    window.start = now;
                    window.count = 1;
                    None
                } else if window.count < window.limit.get() {
                    window.count += 1;
                    None
                } else {
                    // `elapsed < length` here, so this cannot underflow.
                    Some(window.length - elapsed)
                }
            }
        }
    }
}