use bytes::Bytes;
use hashiverse_lib::protocol::payload::payload::CacheRequestTokenV1;
use hashiverse_lib::protocol::peer::Peer;
use hashiverse_lib::tools::buckets::BucketLocation;
use hashiverse_lib::tools::server_id::ServerId;
use hashiverse_lib::tools::time::{TimeMillis, MILLIS_IN_MINUTE};
use hashiverse_lib::tools::tools::leading_agreement_bits_xor;
use hashiverse_lib::tools::types::Id;
use moka::sync::Cache;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use crate::server::post_bundle_caching_shared::{CachedBundle, GetCacheResult, CACHE_HIT_THRESHOLD, CACHE_LOCATION_TTI, CACHE_REQUEST_TOKEN_TTL_DURATION, CACHE_REQUEST_TOKEN_TTL_DURATION_MILLIS};
/// Weight (4 MiB) charged to a location entry that holds no bundle bytes yet, so that
/// placeholder entries created on lookup still count against the cache's byte budget.
const POST_BUNDLE_PLACEHOLDER_WEIGHT: u32 = 4 << 20;
/// Cache state for a single bucket location: the bundles uploaded for it and how many
/// lookups it has received.
struct CachedPostBundleLocationEntry {
    /// Number of lookups recorded for this location; compared against `CACHE_HIT_THRESHOLD`
    /// when deciding whether to hand out a cache request token.
    hit_count: u32,
    /// Cached bundles keyed by the originator peer id that uploaded them.
    bundles: HashMap<Id, CachedBundle>,
}
impl CachedPostBundleLocationEntry {
    /// Creates an empty entry with no bundles and a zero hit count.
    fn placeholder() -> Self {
        Self { bundles: HashMap::new(), hit_count: 0 }
    }

    /// Returns this entry's weight for the byte-bounded location cache.
    ///
    /// The weight is the total byte length of all cached bundles. An entry with no bundle
    /// bytes is charged `POST_BUNDLE_PLACEHOLDER_WEIGHT` instead. The total saturates at
    /// `u32::MAX` instead of truncating (`as u32`) or overflowing (`sum`), so an oversized
    /// entry is never under-weighed.
    fn weight(&self) -> u32 {
        let total = self
            .bundles
            .values()
            .map(|bundle| u32::try_from(bundle.bytes.len()).unwrap_or(u32::MAX))
            .fold(0u32, u32::saturating_add);
        if total == 0 {
            POST_BUNDLE_PLACEHOLDER_WEIGHT
        } else {
            total
        }
    }
}
/// Byte-bounded cache of post bundles, grouped by bucket location id.
///
/// Each location maps to a shared, mutex-protected entry holding at most
/// `max_originators_per_location` bundles (one per originator). A separate, time-limited
/// `inflight` set records locations for which a cache request token has been issued.
pub struct PostBundleCache {
    /// Location id -> per-location entry, weighed by the bytes of the bundles it holds.
    bundles: Cache<Id, Arc<Mutex<CachedPostBundleLocationEntry>>>,
    /// Locations with an outstanding cache request token; markers expire after the token TTL.
    inflight: Cache<Id, ()>,
    /// Upper bound on how many distinct originators' bundles one location may hold.
    max_originators_per_location: usize,
}
impl PostBundleCache {
    /// Creates a cache that keeps at most `max_originators_per_location` bundles per location
    /// and roughly `max_bytes` of weighted entries overall.
    ///
    /// Location entries are weighed by `CachedPostBundleLocationEntry::weight` and expire after
    /// `CACHE_LOCATION_TTI` without access. In-flight token markers live for
    /// `CACHE_REQUEST_TOKEN_TTL_DURATION`.
    pub fn new(max_originators_per_location: usize, max_bytes: u64) -> Self {
        let bundles = Cache::builder()
            .weigher(|_key: &Id, entry: &Arc<Mutex<CachedPostBundleLocationEntry>>| {
                Self::lock_entry(entry).weight()
            })
            .max_capacity(max_bytes)
            .time_to_idle(CACHE_LOCATION_TTI)
            .build();
        let inflight = Cache::builder()
            .time_to_live(CACHE_REQUEST_TOKEN_TTL_DURATION)
            .build();
        Self { max_originators_per_location, bundles, inflight }
    }

    /// Locks a location entry, recovering the guard if a previous holder panicked.
    ///
    /// A poisoned entry is still usable cache data. Refusing it would make every later request
    /// for that location panic until the entry is evicted.
    fn lock_entry(
        entry: &Mutex<CachedPostBundleLocationEntry>,
    ) -> std::sync::MutexGuard<'_, CachedPostBundleLocationEntry> {
        entry.lock().unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Records a lookup for `bucket_location` and returns what the cache can serve for it.
    ///
    /// A placeholder entry is created the first time a location is seen. `cached_items`
    /// contains every cached, non-stale bundle whose originator is not listed in
    /// `already_retrieved_peer_ids`.
    ///
    /// A `CacheRequestTokenV1` is also returned when two conditions hold: the location has
    /// been looked up at least `CACHE_HIT_THRESHOLD` times, and no token is currently in
    /// flight for it. The token is built with the server's signature key and lists the
    /// originators whose cached bundles are still fresh.
    pub fn on_get(
        &self,
        bucket_location: &BucketLocation,
        already_retrieved_peer_ids: &[Id],
        peer_self: &Peer,
        server_id: &ServerId,
        now: TimeMillis,
    ) -> GetCacheResult {
        let location_id = bucket_location.location_id;
        let entry_arc = self.bundles.get_with(location_id, || Arc::new(Mutex::new(CachedPostBundleLocationEntry::placeholder())));

        let (cached_items, already_cached_peer_ids, should_issue_token) = {
            let mut entry = Self::lock_entry(&entry_arc);
            // Saturate: a plain `+= 1` panics in debug builds and wraps to 0 in release builds.
            entry.hit_count = entry.hit_count.saturating_add(1);

            let already_retrieved_set: std::collections::HashSet<Id> = already_retrieved_peer_ids.iter().copied().collect();
            let cached_items: Vec<Bytes> = entry.bundles
                .iter()
                .filter(|&(originator_id, bundle)| !already_retrieved_set.contains(originator_id) && !bundle.is_stale(now))
                .map(|(_, bundle)| bundle.bytes.clone())
                .collect();

            // Only fresh bundles count as "already cached". Reporting a stale one would tell the
            // token holder to skip it, so it could never be refreshed.
            let already_cached_peer_ids: Vec<Id> = entry.bundles
                .iter()
                .filter(|(_, bundle)| !bundle.is_stale(now))
                .map(|(originator_id, _)| *originator_id)
                .collect();

            // Check and claim the in-flight slot while still holding the entry lock. Otherwise
            // two concurrent lookups for this location could both see "no token in flight".
            let should_issue_token = entry.hit_count >= CACHE_HIT_THRESHOLD && !self.inflight.contains_key(&location_id);
            if should_issue_token {
                self.inflight.insert(location_id, ());
            }
            (cached_items, already_cached_peer_ids, should_issue_token)
        };

        // Build the token after the entry lock has been released.
        let cache_request_token = should_issue_token.then(|| {
            let expires_at = now + CACHE_REQUEST_TOKEN_TTL_DURATION_MILLIS;
            CacheRequestTokenV1::new(peer_self.clone(), bucket_location.clone(), expires_at, already_cached_peer_ids, &server_id.keys.signature_key)
        });
        GetCacheResult { cached_items, cache_request_token }
    }

    /// Stores `bundle_bytes` uploaded by `originator_peer_id` under an existing location entry.
    ///
    /// Uploads are accepted only for locations already present in the cache, for example
    /// created by a prior `on_get`. Otherwise nothing is stored and `false` is returned.
    ///
    /// Unsealed bundles get an `expires_at` five minutes after `server_time`; sealed bundles
    /// get none. If the location then holds more than `max_originators_per_location` bundles,
    /// bundles are evicted one at a time:
    /// - first choice is the bundle whose originator id shares the fewest leading bits with
    ///   the location id;
    /// - ties go to the earliest `expires_at`;
    /// - sealed bundles are treated as expiring last.
    ///
    /// Returns `true` if the uploaded bundle is still cached after eviction.
    pub fn on_upload(
        &self,
        location_id: Id,
        originator_peer_id: Id,
        bundle_bytes: Bytes,
        server_time: TimeMillis,
        is_sealed: bool,
    ) -> bool {
        let entry_arc = match self.bundles.get(&location_id) {
            Some(entry_arc) => entry_arc,
            None => return false,
        };

        let retained = {
            let mut entry = Self::lock_entry(&entry_arc);
            let expires_at = if is_sealed { None } else { Some(server_time + MILLIS_IN_MINUTE.const_mul(5)) };
            entry.bundles.insert(originator_peer_id, CachedBundle { bytes: bundle_bytes, expires_at });

            while entry.bundles.len() > self.max_originators_per_location {
                let evict_key = entry.bundles
                    .iter()
                    .min_by_key(|(originator_id, bundle)| {
                        // More agreeing leading bits means "closer" to the location; evict the farthest.
                        let agreement_bits = leading_agreement_bits_xor(originator_id.as_ref(), location_id.as_ref());
                        let expires = bundle.expires_at.unwrap_or(TimeMillis(i64::MAX));
                        (agreement_bits, expires)
                    })
                    .map(|(originator_id, _)| *originator_id);
                match evict_key {
                    Some(key) => {
                        entry.bundles.remove(&key);
                    }
                    None => break,
                }
            }
            entry.bundles.contains_key(&originator_peer_id)
        };

        // moka only runs the weigher on insert/update. Mutating the entry through its mutex
        // therefore leaves the recorded weight stale (e.g. still the placeholder weight).
        // Re-insert the same entry so the byte budget reflects the bundles actually held.
        // This must happen after the lock above is released, because the weigher locks the
        // entry. NOTE(review): if the entry was evicted concurrently, this re-admits it.
        self.bundles.insert(location_id, entry_arc);
        retained
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use hashiverse_lib::tools::buckets::{BucketLocation, BucketType, BUCKET_DURATIONS};
    use hashiverse_lib::tools::server_id::ServerId;
    use hashiverse_lib::tools::time::TimeMillis;
    use hashiverse_lib::tools::time_provider::time_provider::RealTimeProvider;
    use hashiverse_lib::tools::pow_generator::single_threaded_pow_generator::SingleThreadedPowGenerator;
    use hashiverse_lib::tools::types::{Id, Pow};

    /// Per-location originator cap used by every test cache.
    const TEST_MAX_ORIGINATORS: usize = 5;
    /// Byte budget (64 MiB) used by every test cache.
    const TEST_MAX_BYTES: u64 = 64 * 1024 * 1024;
    /// Fixed timestamp shared by the tests and their bucket locations.
    const TEST_NOW: TimeMillis = TimeMillis(1_000_000);

    /// Builds a server identity with `Pow(0)` and the peer record derived from it.
    async fn make_test_server_and_peer() -> anyhow::Result<(ServerId, hashiverse_lib::protocol::peer::Peer)> {
        let clock = RealTimeProvider;
        let generator = SingleThreadedPowGenerator::new();
        let server = ServerId::new("own_pow", &clock, Pow(0), true, &generator).await?;
        let peer = server.to_peer(&clock)?;
        Ok((server, peer))
    }

    /// Builds a user bucket location for a random id at the shortest bucket duration.
    fn make_test_bucket_location() -> BucketLocation {
        BucketLocation::new(BucketType::User, Id::random(), BUCKET_DURATIONS[0], TEST_NOW).unwrap()
    }

    /// Builds a cache with the shared test limits.
    fn make_test_cache() -> PostBundleCache {
        PostBundleCache::new(TEST_MAX_ORIGINATORS, TEST_MAX_BYTES)
    }

    #[tokio::test]
    async fn test_below_threshold_no_token() -> anyhow::Result<()> {
        let (server, peer) = make_test_server_and_peer().await?;
        let cache = make_test_cache();
        let location = make_test_bucket_location();
        for _ in 1..CACHE_HIT_THRESHOLD {
            let lookup = cache.on_get(&location, &[], &peer, &server, TEST_NOW);
            assert!(lookup.cache_request_token.is_none());
            assert!(lookup.cached_items.is_empty());
        }
        Ok(())
    }

    #[tokio::test]
    async fn test_at_threshold_token_issued_then_deduplicated() -> anyhow::Result<()> {
        let (server, peer) = make_test_server_and_peer().await?;
        let cache = make_test_cache();
        let location = make_test_bucket_location();
        for _ in 1..CACHE_HIT_THRESHOLD {
            let lookup = cache.on_get(&location, &[], &peer, &server, TEST_NOW);
            assert!(lookup.cache_request_token.is_none());
        }
        // The threshold-th lookup earns a token...
        let first = cache.on_get(&location, &[], &peer, &server, TEST_NOW);
        assert!(first.cache_request_token.is_some());
        // ...and the next one is suppressed while that token is in flight.
        let second = cache.on_get(&location, &[], &peer, &server, TEST_NOW);
        assert!(second.cache_request_token.is_none());
        Ok(())
    }

    #[tokio::test]
    async fn test_upload_and_retrieval() -> anyhow::Result<()> {
        let (server, peer) = make_test_server_and_peer().await?;
        let cache = make_test_cache();
        let location = make_test_bucket_location();
        let originator = Id::random();
        let payload = Bytes::from_static(b"test_bundle");

        // A lookup must create the location entry before uploads are accepted.
        cache.on_get(&location, &[], &peer, &server, TEST_NOW);
        let stored = cache.on_upload(location.location_id, originator, payload.clone(), TEST_NOW, false);
        assert!(stored);

        let lookup = cache.on_get(&location, &[], &peer, &server, TEST_NOW);
        assert_eq!(lookup.cached_items, vec![payload]);
        Ok(())
    }

    #[tokio::test]
    async fn test_already_retrieved_filtered() -> anyhow::Result<()> {
        let (server, peer) = make_test_server_and_peer().await?;
        let cache = make_test_cache();
        let location = make_test_bucket_location();
        let originator = Id::random();

        cache.on_get(&location, &[], &peer, &server, TEST_NOW);
        cache.on_upload(location.location_id, originator, Bytes::from_static(b"bundle"), TEST_NOW, false);

        let lookup = cache.on_get(&location, &[originator], &peer, &server, TEST_NOW);
        assert!(lookup.cached_items.is_empty());
        Ok(())
    }

    #[tokio::test]
    async fn test_upload_returns_false_when_not_in_cache() -> anyhow::Result<()> {
        let cache = make_test_cache();
        let stored = cache.on_upload(Id::random(), Id::random(), Bytes::from_static(b"bundle"), TEST_NOW, false);
        assert!(!stored);
        Ok(())
    }
}