use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use kellnr_common::normalized_name::NormalizedName;
use kellnr_common::version::Version;
use tracing::{trace, warn};
use crate::DbProvider;
/// Counts crate downloads and persists them through a [`DbProvider`].
///
/// Two independent tallies are kept: one written with the `increase_download_counter*`
/// database methods (reported as "kellnr crates" in logs) and one written with the
/// `increase_cached_download_counter*` methods (reported as "cached crates").
///
/// When `flush_interval` is `0`, every download is written to the database immediately.
/// Otherwise downloads are buffered in memory and only persisted when
/// [`DownloadCounter::flush`] is called; this type does not schedule flushes itself.
pub struct DownloadCounter {
    db: Arc<dyn DbProvider>,
    /// Buffered downloads, keyed by (crate, version), not yet written via
    /// `increase_download_counter_by`.
    counts: Mutex<HashMap<(NormalizedName, Version), u64>>,
    /// Buffered downloads, keyed by (crate, version), not yet written via
    /// `increase_cached_download_counter_by`.
    cached_counts: Mutex<HashMap<(NormalizedName, Version), u64>>,
    /// Only its zero / non-zero value is inspected here: zero disables buffering.
    /// NOTE(review): presumably the caller's flush period; its unit is not visible here.
    flush_interval: u64,
}

impl DownloadCounter {
    /// Panic message used if the `counts` mutex is poisoned.
    const COUNTS_POISONED: &'static str = "download counter lock poisoned";
    /// Panic message used if the `cached_counts` mutex is poisoned.
    const CACHED_COUNTS_POISONED: &'static str = "cached download counter lock poisoned";

    /// Creates a counter backed by `db`.
    ///
    /// A `flush_interval` of `0` writes every download straight to the database; any other
    /// value buffers downloads until [`DownloadCounter::flush`] is called.
    pub fn new(db: Arc<dyn DbProvider>, flush_interval: u64) -> Self {
        Self {
            db,
            counts: Mutex::new(HashMap::new()),
            cached_counts: Mutex::new(HashMap::new()),
            flush_interval,
        }
    }

    /// Records one download of `version` of crate `name` in the kellnr tally.
    ///
    /// With buffering disabled the database is updated immediately; a database error is
    /// logged and that download is not retried. With buffering enabled the download is
    /// added to the in-memory buffer. Despite the name, this never flushes the buffer.
    pub async fn increment_and_maybe_flush(&self, name: NormalizedName, version: Version) {
        if self.flush_interval == 0 {
            if let Err(e) = self.db.increase_download_counter(&name, &version).await {
                warn!("Failed to increment download counter for {name} {version}: {e}");
            }
        } else {
            self.increment(name, version);
        }
    }

    /// Records one download of `version` of crate `name` in the cached tally.
    ///
    /// Same buffering and error semantics as
    /// [`DownloadCounter::increment_and_maybe_flush`], but uses the `*_cached_*`
    /// database methods and buffer.
    pub async fn increment_cached_and_maybe_flush(&self, name: NormalizedName, version: Version) {
        if self.flush_interval == 0 {
            if let Err(e) = self
                .db
                .increase_cached_download_counter(&name, &version)
                .await
            {
                warn!("Failed to increment cached download counter for {name} {version}: {e}");
            }
        } else {
            self.increment_cached(name, version);
        }
    }

    /// Buffers one download in the kellnr tally. Takes `version` by value so no clone is
    /// needed to build the map key.
    fn increment(&self, name: NormalizedName, version: Version) {
        Self::add_count(&self.counts, Self::COUNTS_POISONED, (name, version), 1);
    }

    /// Buffers one download in the cached tally. Takes `version` by value so no clone is
    /// needed to build the map key.
    fn increment_cached(&self, name: NormalizedName, version: Version) {
        Self::add_count(
            &self.cached_counts,
            Self::CACHED_COUNTS_POISONED,
            (name, version),
            1,
        );
    }

    /// Adds `count` to the buffered total for `key`, creating the entry if needed.
    ///
    /// # Panics
    /// Panics with `poisoned_msg` if `counts` is poisoned.
    fn add_count(
        counts: &Mutex<HashMap<(NormalizedName, Version), u64>>,
        poisoned_msg: &str,
        key: (NormalizedName, Version),
        count: u64,
    ) {
        // The guard is a temporary, so the lock is released at the end of this statement.
        *counts.lock().expect(poisoned_msg).entry(key).or_default() += count;
    }

    /// Swaps the buffer out for an empty map and returns what it held.
    ///
    /// # Panics
    /// Panics with `poisoned_msg` if `counts` is poisoned.
    fn take_counts(
        counts: &Mutex<HashMap<(NormalizedName, Version), u64>>,
        poisoned_msg: &str,
    ) -> HashMap<(NormalizedName, Version), u64> {
        std::mem::take(&mut *counts.lock().expect(poisoned_msg))
    }

    /// Writes all buffered download counts to the database.
    ///
    /// Both buffers are swapped for empty maps before any database call. As a result, no lock
    /// is held across an `.await`, and new downloads keep accumulating while the flush runs.
    /// An entry whose database update fails is logged and merged back into its buffer, so it
    /// is retried on the next flush.
    ///
    /// NOTE(review): an entry that keeps failing is retried (and warned about) on every
    /// flush, with no limit.
    ///
    /// NOTE(review): if this future is dropped part-way through (e.g. by a timeout or a task
    /// abort), counts that were taken but not yet written are lost.
    pub async fn flush(&self) {
        let counts = Self::take_counts(&self.counts, Self::COUNTS_POISONED);
        let cached_counts = Self::take_counts(&self.cached_counts, Self::CACHED_COUNTS_POISONED);
        if counts.is_empty() && cached_counts.is_empty() {
            return;
        }

        // Count only successful writes; failed entries are re-queued, not flushed.
        let mut flushed_kellnr = 0usize;
        for ((name, version), count) in counts {
            match self
                .db
                .increase_download_counter_by(&name, &version, count)
                .await
            {
                Ok(_) => flushed_kellnr += 1,
                Err(e) => {
                    warn!(
                        "Failed to flush download counter for {name} {version} (count={count}): {e}"
                    );
                    Self::add_count(&self.counts, Self::COUNTS_POISONED, (name, version), count);
                }
            }
        }

        let mut flushed_cached = 0usize;
        for ((name, version), count) in cached_counts {
            match self
                .db
                .increase_cached_download_counter_by(&name, &version, count)
                .await
            {
                Ok(_) => flushed_cached += 1,
                Err(e) => {
                    warn!(
                        "Failed to flush cached download counter for {name} {version} (count={count}): {e}"
                    );
                    Self::add_count(
                        &self.cached_counts,
                        Self::CACHED_COUNTS_POISONED,
                        (name, version),
                        count,
                    );
                }
            }
        }

        trace!(
            "Flushed download counters: {flushed_kellnr} kellnr crates, {flushed_cached} cached crates"
        );
    }
}