use bytes::Bytes;
use hashiverse_lib::protocol::payload::payload::CacheRequestTokenV1;
use hashiverse_lib::protocol::peer::Peer;
use hashiverse_lib::tools::buckets::BucketLocation;
use hashiverse_lib::tools::server_id::ServerId;
use hashiverse_lib::tools::time::{TimeMillis, MILLIS_IN_MINUTE};
use hashiverse_lib::tools::time_provider::moka_clock::TimeProviderMokaClock;
use hashiverse_lib::tools::time_provider::time_provider::TimeProvider;
use hashiverse_lib::tools::tools::leading_agreement_bits_xor;
use hashiverse_lib::tools::types::Id;
use moka::sync::Cache;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use crate::server::post_bundle_caching_shared::{CachedBundle, GetCacheResult, CACHE_HIT_THRESHOLD, CACHE_LOCATION_TTI, CACHE_REQUEST_TOKEN_TTL_DURATION, CACHE_REQUEST_TOKEN_TTL_DURATION_MILLIS};
/// Weight (4 MiB) charged to a location entry that currently holds no bundle bytes, so that empty
/// placeholder entries still count against the cache's byte budget.
const POST_BUNDLE_PLACEHOLDER_WEIGHT: u32 = 4 << 20;
/// Per-location cache state: how often the location was requested and the bundles cached for it.
struct CachedPostBundleLocationEntry {
    /// Number of `on_get` requests observed for this location.
    hit_count: u32,
    /// Cached bundle for each originator peer id.
    bundles: HashMap<Id, CachedBundle>,
}
impl CachedPostBundleLocationEntry {
    /// Creates an entry with no cached bundles and a zero hit counter.
    fn placeholder() -> Self {
        Self { bundles: HashMap::new(), hit_count: 0 }
    }

    /// Weight reported to the moka weigher.
    ///
    /// This is the total byte length of all cached bundles, saturating at `u32::MAX`. It saturates
    /// rather than truncating or wrapping, so a very large entry can never masquerade as a small one.
    /// An entry with no cached bytes (e.g. a fresh placeholder) weighs
    /// `POST_BUNDLE_PLACEHOLDER_WEIGHT` instead.
    fn weight(&self) -> u32 {
        let total = self.bundles.values().fold(0u32, |acc, bundle| {
            let len = u32::try_from(bundle.bytes.len()).unwrap_or(u32::MAX);
            acc.saturating_add(len)
        });
        if total == 0 {
            POST_BUNDLE_PLACEHOLDER_WEIGHT
        } else {
            total
        }
    }
}
/// Server-side cache of post bundles, keyed by bucket location id.
///
/// Counts requests per location, keeps a bounded number of originator bundles per location, and
/// remembers which locations currently have a cache request token outstanding.
pub struct PostBundleCache {
    /// Location id -> that location's hit counter and cached bundles; capacity is enforced by weight.
    bundles: Cache<Id, Arc<Mutex<CachedPostBundleLocationEntry>>>,
    /// Locations for which a cache request token was recently issued; entries expire after
    /// `CACHE_REQUEST_TOKEN_TTL_DURATION`.
    inflight: Cache<Id, ()>,
    /// Upper bound on the number of originator bundles kept for any single location.
    max_originators_per_location: usize,
}
impl PostBundleCache {
pub fn new(max_originators_per_location: usize, max_bytes: u64, time_provider: Arc<dyn TimeProvider>) -> Self {
let clock = Arc::new(TimeProviderMokaClock::new(time_provider));
let bundles = Cache::builder()
.weigher(|_key: &Id, entry: &Arc<Mutex<CachedPostBundleLocationEntry>>| {
entry.lock().map(|e| e.weight()).unwrap_or(POST_BUNDLE_PLACEHOLDER_WEIGHT)
})
.max_capacity(max_bytes)
.time_to_idle(CACHE_LOCATION_TTI)
.external_clock(clock.clone())
.build();
let inflight = Cache::builder()
.time_to_live(CACHE_REQUEST_TOKEN_TTL_DURATION)
.external_clock(clock)
.build();
Self { max_originators_per_location, bundles, inflight }
}
pub fn on_get(
&self,
bucket_location: &BucketLocation,
already_retrieved_peer_ids: &[Id],
peer_self: &Peer,
server_id: &ServerId,
now: TimeMillis,
) -> GetCacheResult {
let location_id = bucket_location.location_id;
let entry_arc = self.bundles.get_with(location_id, || Arc::new(Mutex::new(CachedPostBundleLocationEntry::placeholder())));
let (cached_items, already_cached_peer_ids, should_issue_token) = {
let mut entry = entry_arc.lock().unwrap();
entry.hit_count += 1;
let already_retrieved_set: std::collections::HashSet<Id> = already_retrieved_peer_ids.iter().copied().collect();
let cached_items: Vec<Bytes> = entry.bundles
.iter()
.filter(|(originator_id, bundle)| !already_retrieved_set.contains(originator_id) && !bundle.is_stale(now))
.map(|(_, bundle)| bundle.bytes.clone())
.collect();
let already_cached_peer_ids: Vec<Id> = entry.bundles.keys().copied().collect();
let should_issue_token = entry.hit_count >= CACHE_HIT_THRESHOLD && !self.inflight.contains_key(&location_id);
(cached_items, already_cached_peer_ids, should_issue_token)
};
let cache_request_token = if should_issue_token {
self.inflight.insert(location_id, ());
let expires_at = now + CACHE_REQUEST_TOKEN_TTL_DURATION_MILLIS;
Some(CacheRequestTokenV1::new(peer_self.clone(), bucket_location.clone(), expires_at, already_cached_peer_ids, &server_id.keys.signature_key))
} else {
None
};
GetCacheResult { cached_items, cache_request_token }
}
pub fn on_upload(
&self,
location_id: Id,
originator_peer_id: Id,
bundle_bytes: Bytes,
server_time: TimeMillis,
is_sealed: bool,
) -> bool {
let entry_arc = match self.bundles.get(&location_id) {
Some(e) => e,
None => return false, };
let mut entry = entry_arc.lock().unwrap();
let expires_at = if is_sealed { None } else { Some(server_time + MILLIS_IN_MINUTE.const_mul(5)) };
let bundle = CachedBundle { bytes: bundle_bytes, expires_at };
entry.bundles.insert(originator_peer_id, bundle);
while entry.bundles.len() > self.max_originators_per_location {
let evict_key = entry.bundles
.iter()
.min_by(|(id_a, bundle_a), (id_b, bundle_b)| {
let distance_a = leading_agreement_bits_xor(id_a.as_ref(), location_id.as_ref());
let distance_b = leading_agreement_bits_xor(id_b.as_ref(), location_id.as_ref());
distance_a.cmp(&distance_b).then_with(|| {
let expires_a = bundle_a.expires_at.unwrap_or(TimeMillis(i64::MAX));
let expires_b = bundle_b.expires_at.unwrap_or(TimeMillis(i64::MAX));
expires_a.cmp(&expires_b)
})
})
.map(|(id, _)| *id);
if let Some(k) = evict_key {
entry.bundles.remove(&k);
}
}
entry.bundles.contains_key(&originator_peer_id)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use hashiverse_lib::tools::buckets::{BucketLocation, BucketType, BUCKET_DURATIONS};
    use hashiverse_lib::tools::pow_generator::single_threaded_pow_generator::SingleThreadedPowGenerator;
    use hashiverse_lib::tools::server_id::ServerId;
    use hashiverse_lib::tools::time::TimeMillis;
    use hashiverse_lib::tools::time_provider::time_provider::RealTimeProvider;
    use hashiverse_lib::tools::types::{Id, Pow};
    use std::collections::HashSet;

    /// Timestamp used both for the bucket location and as the request time in every test.
    const TEST_NOW: TimeMillis = TimeMillis(1_000_000);

    /// Byte budget shared by every test cache (64 MiB).
    const TEST_MAX_BYTES: u64 = 64 * 1024 * 1024;

    /// Builds a cache backed by the real clock with the shared test byte budget.
    fn new_cache(max_originators_per_location: usize) -> PostBundleCache {
        PostBundleCache::new(max_originators_per_location, TEST_MAX_BYTES, Arc::new(RealTimeProvider))
    }

    /// Creates a server identity (zero PoW) and the peer record describing it.
    async fn make_test_server_and_peer() -> anyhow::Result<(ServerId, hashiverse_lib::protocol::peer::Peer)> {
        let time_provider = RealTimeProvider;
        let pow_generator = SingleThreadedPowGenerator::new();
        let server_id = ServerId::new("own_pow", &time_provider, Pow(0), true, &pow_generator).await?;
        let peer = server_id.to_peer(&time_provider)?;
        Ok((server_id, peer))
    }

    /// A user bucket location with a random id, using the first bucket duration.
    fn make_test_bucket_location() -> BucketLocation {
        BucketLocation::new(BucketType::User, Id::random(), BUCKET_DURATIONS[0], TEST_NOW).unwrap()
    }

    #[tokio::test]
    async fn test_below_threshold_no_token() -> anyhow::Result<()> {
        let (server_id, peer_self) = make_test_server_and_peer().await?;
        let cache = new_cache(5);
        let location = make_test_bucket_location();
        for _ in 1..CACHE_HIT_THRESHOLD {
            let result = cache.on_get(&location, &[], &peer_self, &server_id, TEST_NOW);
            assert!(result.cache_request_token.is_none());
            assert!(result.cached_items.is_empty());
        }
        Ok(())
    }

    #[tokio::test]
    async fn test_at_threshold_token_issued_then_deduplicated() -> anyhow::Result<()> {
        let (server_id, peer_self) = make_test_server_and_peer().await?;
        let cache = new_cache(5);
        let location = make_test_bucket_location();
        let request = || cache.on_get(&location, &[], &peer_self, &server_id, TEST_NOW);
        for _ in 1..CACHE_HIT_THRESHOLD {
            assert!(request().cache_request_token.is_none());
        }
        assert!(request().cache_request_token.is_some());
        assert!(request().cache_request_token.is_none());
        Ok(())
    }

    #[tokio::test]
    async fn test_upload_and_retrieval() -> anyhow::Result<()> {
        let (server_id, peer_self) = make_test_server_and_peer().await?;
        let cache = new_cache(5);
        let location = make_test_bucket_location();
        let originator_id = Id::random();
        let payload = Bytes::from_static(b"test_bundle");

        // The first request creates the location entry that the upload needs.
        cache.on_get(&location, &[], &peer_self, &server_id, TEST_NOW);
        assert!(cache.on_upload(location.location_id, originator_id, payload.clone(), TEST_NOW, false));

        let result = cache.on_get(&location, &[], &peer_self, &server_id, TEST_NOW);
        assert_eq!(result.cached_items, vec![payload]);
        Ok(())
    }

    #[tokio::test]
    async fn test_already_retrieved_filtered() -> anyhow::Result<()> {
        let (server_id, peer_self) = make_test_server_and_peer().await?;
        let cache = new_cache(5);
        let location = make_test_bucket_location();
        let originator_id = Id::random();

        cache.on_get(&location, &[], &peer_self, &server_id, TEST_NOW);
        cache.on_upload(location.location_id, originator_id, Bytes::from_static(b"bundle"), TEST_NOW, false);

        let result = cache.on_get(&location, &[originator_id], &peer_self, &server_id, TEST_NOW);
        assert!(result.cached_items.is_empty());
        Ok(())
    }

    #[tokio::test]
    async fn test_upload_returns_false_when_not_in_cache() -> anyhow::Result<()> {
        let cache = new_cache(5);
        let accepted = cache.on_upload(Id::random(), Id::random(), Bytes::from_static(b"bundle"), TEST_NOW, false);
        assert!(!accepted);
        Ok(())
    }

    #[tokio::test]
    async fn test_overflow_keeps_closest_originators() -> anyhow::Result<()> {
        let (server_id, peer_self) = make_test_server_and_peer().await?;
        let cache = new_cache(3);
        let location = make_test_bucket_location();
        let location_id = location.location_id;
        cache.on_get(&location, &[], &peer_self, &server_id, TEST_NOW);

        // Flipping bit `agreement` (MSB-first) yields an id whose first `agreement` bits match the location.
        let originator_at = |agreement: usize| -> Id {
            let mut raw = location_id.0;
            raw[agreement / 8] ^= 0x80 >> (agreement % 8);
            Id(raw)
        };
        let label = |agreement: usize| format!("bundle-agreement-{}", agreement);

        for agreement in [20usize, 40, 60, 80, 100].iter().copied() {
            cache.on_upload(location_id, originator_at(agreement), Bytes::from(label(agreement)), TEST_NOW, true);
        }

        let result = cache.on_get(&location, &[], &peer_self, &server_id, TEST_NOW);
        let cached: HashSet<Vec<u8>> = result.cached_items.iter().map(|item| item.to_vec()).collect();
        assert_eq!(3, cached.len(), "cache must keep exactly max_originators_per_location entries");
        for kept in [60usize, 80, 100].iter().copied() {
            assert!(cached.contains(label(kept).as_bytes()), "closest originator (agreement {}) must be kept", kept);
        }
        for evicted in [20usize, 40].iter().copied() {
            assert!(!cached.contains(label(evicted).as_bytes()), "furthest originator (agreement {}) must be evicted", evicted);
        }
        Ok(())
    }

    #[tokio::test]
    async fn test_cache_request_token_expiry() -> anyhow::Result<()> {
        let (server_id, peer_self) = make_test_server_and_peer().await?;
        let cache = new_cache(5);
        let location = make_test_bucket_location();

        // Keep the last token seen across exactly CACHE_HIT_THRESHOLD requests.
        let token = (0..CACHE_HIT_THRESHOLD)
            .filter_map(|_| cache.on_get(&location, &[], &peer_self, &server_id, TEST_NOW).cache_request_token)
            .last()
            .expect("server issues a token at the hit threshold");

        let ttl = CACHE_REQUEST_TOKEN_TTL_DURATION_MILLIS.0;
        assert!(!token.is_expired(TEST_NOW), "token must be valid at issue time");
        assert!(!token.is_expired(TimeMillis(TEST_NOW.0 + ttl - 1)), "token must be valid just before its TTL elapses");
        assert!(token.is_expired(TimeMillis(TEST_NOW.0 + ttl + 1)), "token must be expired once its TTL has elapsed");
        Ok(())
    }
}