use std::collections::HashMap;
use std::time::{Duration, Instant};
use parking_lot::Mutex;
use thiserror::Error;
#[derive(Debug, Error, Eq, PartialEq)]
pub enum HintStoreError {
#[error("hint store over capacity ({max_bytes} bytes)")]
OverCapacity {
max_bytes: u64,
},
#[error("hint TTL must be greater than zero")]
ZeroTtl,
#[error("hint payload is empty")]
EmptyPayload,
}
/// A payload buffered for one peer until it is taken or its deadline passes.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Hint {
    /// Index of the peer this hint is queued for.
    pub peer_idx: u32,
    /// Payload bytes; their length is what counts against the store's budget.
    pub payload: Vec<u8>,
    /// The hint counts as expired once the current instant reaches this value.
    pub deadline: Instant,
}

impl Hint {
    /// Bytes this hint charges against the store's budget: the payload
    /// length, saturating at `u64::MAX` on platforms where it would not fit.
    #[must_use]
    fn weight(&self) -> u64 {
        let len = self.payload.len();
        u64::try_from(len).unwrap_or(u64::MAX)
    }
}
/// Point-in-time snapshot of a [`HintStore`]'s occupancy and counters.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct HintStoreStats {
    /// Number of hints currently buffered across all peers.
    pub hint_count: usize,
    /// Total payload bytes currently buffered.
    pub bytes: u64,
    /// Configured byte limit; `0` means no limit is enforced.
    pub max_bytes: u64,
    /// Cumulative number of hints discarded because their deadline passed.
    pub expired_total: u64,
    /// Cumulative number of enqueue attempts refused for lack of capacity.
    pub rejected_over_capacity_total: u64,
}
/// Per-peer buffer of [`Hint`]s with an optional cap on total payload bytes.
///
/// All state sits behind a single internal mutex, so every method takes
/// `&self` and the store can be shared by reference.
#[derive(Debug)]
pub struct HintStore {
    /// Mutable state; each public method holds the lock for its whole body.
    inner: Mutex<Inner>,
}
/// Lock-protected state of a [`HintStore`].
#[derive(Debug)]
struct Inner {
    /// Pending hints keyed by peer index; each queue is in enqueue order.
    by_peer: HashMap<u32, Vec<Hint>>,
    /// Sum of the payload lengths of every hint held in `by_peer`.
    bytes: u64,
    /// Byte budget checked on enqueue; `0` disables the check.
    max_bytes: u64,
    /// Cumulative count of hints discarded because their deadline passed.
    expired_total: u64,
    /// Cumulative count of enqueue attempts refused for lack of capacity.
    rejected_over_capacity_total: u64,
}
impl HintStore {
    /// Creates an empty store holding at most `max_bytes` of payload data.
    ///
    /// A `max_bytes` of `0` disables the capacity check entirely.
    #[must_use]
    pub fn new(max_bytes: u64) -> Self {
        Self {
            inner: Mutex::new(Inner {
                by_peer: HashMap::new(),
                bytes: 0,
                max_bytes,
                expired_total: 0,
                rejected_over_capacity_total: 0,
            }),
        }
    }

    /// Buffers `payload` for `peer_idx`, expiring it `ttl` from now.
    ///
    /// Hints for the same peer are kept in enqueue order. A `ttl` too large to
    /// be added to the current instant is clamped to a far-future deadline
    /// instead of panicking.
    ///
    /// # Errors
    ///
    /// - [`HintStoreError::ZeroTtl`] if `ttl` is zero.
    /// - [`HintStoreError::EmptyPayload`] if `payload` is empty.
    /// - [`HintStoreError::OverCapacity`] if a byte limit is configured and
    ///   storing the payload would exceed it. Each such rejection increments
    ///   the `rejected_over_capacity_total` counter.
    pub fn enqueue(
        &self,
        peer_idx: u32,
        payload: Vec<u8>,
        ttl: Duration,
    ) -> Result<(), HintStoreError> {
        if ttl.is_zero() {
            return Err(HintStoreError::ZeroTtl);
        }
        if payload.is_empty() {
            return Err(HintStoreError::EmptyPayload);
        }
        // Build the hint before locking so the critical section stays small and
        // the byte accounting goes through the single `Hint::weight` definition.
        let hint = Hint {
            peer_idx,
            payload,
            deadline: Self::saturating_deadline(Instant::now(), ttl),
        };
        let weight = hint.weight();

        let mut inner = self.inner.lock();
        if inner.max_bytes > 0 && inner.bytes.saturating_add(weight) > inner.max_bytes {
            inner.rejected_over_capacity_total =
                inner.rejected_over_capacity_total.saturating_add(1);
            return Err(HintStoreError::OverCapacity {
                max_bytes: inner.max_bytes,
            });
        }
        inner.by_peer.entry(peer_idx).or_default().push(hint);
        inner.bytes = inner.bytes.saturating_add(weight);
        Ok(())
    }

    /// Removes and returns every still-live hint queued for `peer_idx`,
    /// oldest first.
    ///
    /// The peer's whole queue is taken out of the store, and the bytes of every
    /// removed hint are released. Hints whose deadline is at or before the
    /// current instant are discarded and added to `expired_total` instead of
    /// being returned. An unknown peer yields an empty `Vec`.
    pub fn take_for(&self, peer_idx: u32) -> Vec<Hint> {
        let now = Instant::now();
        let mut inner = self.inner.lock();
        let Some(queue) = inner.by_peer.remove(&peer_idx) else {
            return Vec::new();
        };

        let mut live = Vec::with_capacity(queue.len());
        let mut expired = 0u64;
        for hint in queue {
            // Every hint leaves the store here, delivered or not.
            inner.bytes = inner.bytes.saturating_sub(hint.weight());
            if hint.deadline <= now {
                expired = expired.saturating_add(1);
            } else {
                live.push(hint);
            }
        }
        inner.expired_total = inner.expired_total.saturating_add(expired);
        live
    }

    /// Drops every hint whose deadline is at or before `now` and returns how
    /// many were dropped.
    ///
    /// Dropped hints release their bytes and are added to `expired_total`.
    /// Peers left with no hints are removed from the store.
    pub fn expire_now(&self, now: Instant) -> usize {
        let mut inner = self.inner.lock();
        let mut dropped = 0usize;
        let mut freed = 0u64;
        inner.by_peer.retain(|_, queue| {
            queue.retain(|hint| {
                let live = hint.deadline > now;
                if !live {
                    dropped += 1;
                    freed = freed.saturating_add(hint.weight());
                }
                live
            });
            // Forget peers whose queue is now empty so dead keys don't accumulate.
            !queue.is_empty()
        });
        inner.bytes = inner.bytes.saturating_sub(freed);
        inner.expired_total = inner
            .expired_total
            .saturating_add(u64::try_from(dropped).unwrap_or(u64::MAX));
        dropped
    }

    /// Total number of hints buffered across all peers.
    ///
    /// Hints past their deadline are included until `take_for` or
    /// `expire_now` removes them.
    #[must_use]
    pub fn total_len(&self) -> usize {
        let inner = self.inner.lock();
        inner.by_peer.values().map(Vec::len).sum()
    }

    /// Number of hints buffered for `peer_idx` (`0` for an unknown peer).
    ///
    /// Hints past their deadline are included until they are purged.
    #[must_use]
    pub fn len_for(&self, peer_idx: u32) -> usize {
        let inner = self.inner.lock();
        inner.by_peer.get(&peer_idx).map_or(0, Vec::len)
    }

    /// Returns a consistent snapshot of occupancy and counters, taken under a
    /// single lock acquisition.
    #[must_use]
    pub fn stats(&self) -> HintStoreStats {
        let inner = self.inner.lock();
        HintStoreStats {
            hint_count: inner.by_peer.values().map(Vec::len).sum(),
            bytes: inner.bytes,
            max_bytes: inner.max_bytes,
            expired_total: inner.expired_total,
            rejected_over_capacity_total: inner.rejected_over_capacity_total,
        }
    }

    /// Indices of peers that currently have at least one buffered hint, in
    /// unspecified (hash-map iteration) order.
    #[must_use]
    pub fn peers_with_hints(&self) -> Vec<u32> {
        let inner = self.inner.lock();
        inner
            .by_peer
            .iter()
            .filter(|(_, queue)| !queue.is_empty())
            .map(|(&peer, _)| peer)
            .collect()
    }

    /// Returns `now + ttl`, or the furthest `now + ttl / 2^k` that is
    /// representable when the full sum would overflow.
    ///
    /// The plain `Instant + Duration` operator panics on overflow, and `ttl` is
    /// caller-controlled, so an absurdly large TTL must not abort the caller.
    fn saturating_deadline(now: Instant, ttl: Duration) -> Instant {
        let mut ttl = ttl;
        loop {
            if let Some(deadline) = now.checked_add(ttl) {
                return deadline;
            }
            // Halving reaches zero after finitely many steps, and `now + 0` is
            // always representable, so the loop terminates.
            ttl /= 2;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a payload made of `n` copies of byte `b`.
    fn payload(b: u8, n: usize) -> Vec<u8> {
        vec![b; n]
    }

    #[test]
    fn enqueue_and_take_round_trip() {
        let store = HintStore::new(1024);
        store
            .enqueue(3, payload(b'a', 4), Duration::from_secs(60))
            .unwrap();
        store
            .enqueue(3, payload(b'b', 4), Duration::from_secs(60))
            .unwrap();
        store
            .enqueue(7, payload(b'c', 4), Duration::from_secs(60))
            .unwrap();
        assert_eq!(store.total_len(), 3);
        let drained = store.take_for(3);
        assert_eq!(drained.len(), 2);
        assert_eq!(drained[0].payload, payload(b'a', 4));
        assert_eq!(drained[1].payload, payload(b'b', 4));
        assert_eq!(store.len_for(3), 0);
        assert_eq!(store.len_for(7), 1);
        assert_eq!(store.total_len(), 1);
    }

    #[test]
    fn enqueue_rejects_over_capacity() {
        let store = HintStore::new(8);
        store
            .enqueue(0, payload(b'x', 6), Duration::from_secs(60))
            .unwrap();
        let err = store
            .enqueue(0, payload(b'y', 4), Duration::from_secs(60))
            .unwrap_err();
        assert_eq!(err, HintStoreError::OverCapacity { max_bytes: 8 });
        assert_eq!(store.stats().bytes, 6);
        assert_eq!(store.stats().rejected_over_capacity_total, 1);
        let drained = store.take_for(0);
        assert_eq!(drained.len(), 1);
        store
            .enqueue(0, payload(b'y', 4), Duration::from_secs(60))
            .unwrap();
    }

    #[test]
    fn expire_now_drops_old_hints() {
        let store = HintStore::new(64);
        store
            .enqueue(1, payload(b'a', 3), Duration::from_millis(1))
            .unwrap();
        store
            .enqueue(1, payload(b'b', 3), Duration::from_secs(60))
            .unwrap();
        std::thread::sleep(Duration::from_millis(5));
        let now = Instant::now();
        let dropped = store.expire_now(now);
        assert_eq!(dropped, 1);
        assert_eq!(store.len_for(1), 1);
        let stats = store.stats();
        assert_eq!(stats.expired_total, 1);
        assert_eq!(stats.bytes, 3);
        let drained = store.take_for(1);
        assert_eq!(drained[0].payload, payload(b'b', 3));
    }

    #[test]
    fn expire_now_removes_fully_expired_peers() {
        let store = HintStore::new(64);
        store
            .enqueue(1, payload(b'a', 2), Duration::from_millis(1))
            .unwrap();
        store
            .enqueue(2, payload(b'b', 5), Duration::from_secs(60))
            .unwrap();
        std::thread::sleep(Duration::from_millis(5));
        assert_eq!(store.expire_now(Instant::now()), 1);
        assert_eq!(store.peers_with_hints(), vec![2]);
        assert_eq!(store.stats().bytes, 5);
        assert!(store.take_for(1).is_empty());
    }

    #[test]
    fn take_for_skips_already_expired() {
        let store = HintStore::new(64);
        store
            .enqueue(2, payload(b'a', 3), Duration::from_millis(1))
            .unwrap();
        store
            .enqueue(2, payload(b'b', 3), Duration::from_secs(60))
            .unwrap();
        std::thread::sleep(Duration::from_millis(5));
        let drained = store.take_for(2);
        assert_eq!(drained.len(), 1);
        assert_eq!(drained[0].payload, payload(b'b', 3));
        assert_eq!(store.stats().expired_total, 1);
    }

    #[test]
    fn take_for_releases_bytes_and_unknown_peer_is_empty() {
        let store = HintStore::new(16);
        assert!(store.take_for(9).is_empty());
        store
            .enqueue(9, payload(b'z', 10), Duration::from_secs(60))
            .unwrap();
        assert_eq!(store.stats().bytes, 10);
        let drained = store.take_for(9);
        assert_eq!(drained.len(), 1);
        assert_eq!(drained[0].peer_idx, 9);
        let stats = store.stats();
        assert_eq!(stats.bytes, 0);
        assert_eq!(stats.hint_count, 0);
    }

    #[test]
    fn enqueue_rejects_zero_ttl_and_empty_payload() {
        let store = HintStore::new(64);
        let err = store
            .enqueue(0, payload(b'x', 1), Duration::from_secs(0))
            .unwrap_err();
        assert_eq!(err, HintStoreError::ZeroTtl);
        let err = store
            .enqueue(0, Vec::new(), Duration::from_secs(60))
            .unwrap_err();
        assert_eq!(err, HintStoreError::EmptyPayload);
        assert_eq!(store.total_len(), 0);
    }

    #[test]
    fn mixed_peer_queues_are_independent() {
        let store = HintStore::new(0);
        store
            .enqueue(0, payload(b'a', 1), Duration::from_secs(60))
            .unwrap();
        store
            .enqueue(1, payload(b'b', 1), Duration::from_secs(60))
            .unwrap();
        store
            .enqueue(2, payload(b'c', 1), Duration::from_secs(60))
            .unwrap();
        assert_eq!(store.total_len(), 3);
        let mut peers = store.peers_with_hints();
        peers.sort_unstable();
        assert_eq!(peers, vec![0, 1, 2]);
        let drained = store.take_for(1);
        assert_eq!(drained.len(), 1);
        assert_eq!(drained[0].payload, payload(b'b', 1));
        assert_eq!(store.len_for(0), 1);
        assert_eq!(store.len_for(1), 0);
        assert_eq!(store.len_for(2), 1);
    }

    #[test]
    fn empty_max_bytes_means_unbounded() {
        let store = HintStore::new(0);
        for _ in 0..1024 {
            store
                .enqueue(0, payload(b'x', 1024), Duration::from_secs(60))
                .unwrap();
        }
        assert_eq!(store.total_len(), 1024);
    }

    #[test]
    fn expire_now_no_op_when_nothing_old() {
        let store = HintStore::new(64);
        store
            .enqueue(0, payload(b'x', 3), Duration::from_secs(60))
            .unwrap();
        let dropped = store.expire_now(Instant::now());
        assert_eq!(dropped, 0);
        assert_eq!(store.total_len(), 1);
    }

    #[test]
    fn stats_track_capacity_and_bytes() {
        let store = HintStore::new(1024);
        store
            .enqueue(0, payload(b'x', 100), Duration::from_secs(60))
            .unwrap();
        let s = store.stats();
        assert_eq!(s.hint_count, 1);
        assert_eq!(s.bytes, 100);
        assert_eq!(s.max_bytes, 1024);
    }
}