use std::cmp::{max, min};
use std::collections::HashMap;
use std::hash::Hash;
use alloy_primitives::BlockNumber;
/// A value that can absorb another value of the same type in place.
///
/// [`BlockRangeCache`] calls [`Mergeable::merge`] to fold cached values together when it
/// combines the entries stored for overlapping block ranges.
pub trait Mergeable {
/// Merges `other` into `self`, mutating `self`; `other` is only borrowed immutably.
fn merge(&mut self, other: &Self);
}
/// Cache of [`Mergeable`] values indexed by a caller-chosen key plus a block range.
///
/// Every entry is stored under the tuple `(key, start_block, end_block)`; the range bounds
/// are treated as inclusive by the gap computation in this file.
#[derive(Debug, Clone)]
pub struct BlockRangeCache<K, V>
where
    K: Clone + Eq + Hash,
    V: Mergeable + Clone,
{
    /// Backing storage: `(key, start_block, end_block)` -> cached value.
    cache: HashMap<(K, BlockNumber, BlockNumber), V>,
}

// `Default` is implemented by hand rather than derived: the derive would add `K: Default`
// and `V: Default` bounds, although building an empty map needs neither.
impl<K, V> Default for BlockRangeCache<K, V>
where
    K: Clone + Eq + Hash,
    V: Mergeable + Clone,
{
    /// Creates an empty cache.
    fn default() -> Self {
        Self {
            cache: HashMap::new(),
        }
    }
}
impl<K, V> BlockRangeCache<K, V>
where
    K: Clone + Eq + Hash,
    V: Mergeable + Clone,
{
    /// Returns a clone of the cached value whose range fully contains
    /// `start_block..=end_block` for `key`, or `None` if no single entry does.
    ///
    /// The value is returned exactly as stored: when the matching entry spans a wider range
    /// than requested, the result still describes that whole cached range.
    pub fn get(&self, key: &K, start_block: BlockNumber, end_block: BlockNumber) -> Option<V> {
        // Fast path: an exact-range hit is a single hash lookup.
        if let Some(result) = self.cache.get(&(key.clone(), start_block, end_block)) {
            return Some(result.clone());
        }

        // Slow path: scan for any entry of this key that encloses the requested range.
        self.cache
            .iter()
            .find_map(|((cached_key, cached_start, cached_end), result)| {
                let encloses = cached_key == key
                    && *cached_start <= start_block
                    && *cached_end >= end_block;
                encloses.then(|| result.clone())
            })
    }

    /// Collects every entry of `key` whose range intersects `start_block..=end_block`,
    /// ordered by ascending start block.
    fn find_overlapping(
        &self,
        key: &K,
        start_block: BlockNumber,
        end_block: BlockNumber,
    ) -> Vec<((K, BlockNumber, BlockNumber), &V)> {
        let mut overlapping: Vec<_> = self
            .cache
            .iter()
            .filter(|((cached_key, cached_start, cached_end), _)| {
                // Two inclusive ranges intersect unless one ends before the other begins.
                cached_key == key && *cached_end >= start_block && *cached_start <= end_block
            })
            .map(|(cache_key, result)| (cache_key.clone(), result))
            .collect();
        overlapping.sort_by_key(|((_, start, _), _)| *start);
        overlapping
    }

    /// Stores `value` for `key` over `start_block..=end_block`.
    ///
    /// Every existing entry of `key` that intersects the new range is removed and merged
    /// into `value` (in ascending start-block order); the result is stored under the union
    /// of all the ranges involved. Without any intersecting entry the value is stored as is.
    ///
    /// NOTE(review): nothing here validates `start_block <= end_block`; callers are
    /// presumably expected to pass well-formed ranges.
    pub fn insert(&mut self, key: K, start_block: BlockNumber, end_block: BlockNumber, value: V) {
        // Only the owned keys are kept so the shared borrow of `self` ends before mutation.
        let overlapping_keys: Vec<(K, BlockNumber, BlockNumber)> = self
            .find_overlapping(&key, start_block, end_block)
            .into_iter()
            .map(|(cache_key, _)| cache_key)
            .collect();

        let mut merged_value = value;
        let mut min_start = start_block;
        let mut max_end = end_block;
        for cache_key in overlapping_keys {
            // Removing hands back the owned value, so no second lookup or key clone is needed.
            if let Some(cached_value) = self.cache.remove(&cache_key) {
                min_start = min(min_start, cache_key.1);
                max_end = max(max_end, cache_key.2);
                merged_value.merge(&cached_value);
            }
        }

        self.cache.insert((key, min_start, max_end), merged_value);
    }

    /// Works out which parts of `start_block..=end_block` are missing from the cache for
    /// `key`.
    ///
    /// Returns `(value, gaps)`:
    /// - a single entry encloses the request: `(Some(clone of that entry), [])`;
    /// - nothing intersects the request: `(None, [(start_block, end_block)])`, and
    ///   `create_empty` is not called;
    /// - otherwise: `create_empty()` with every intersecting entry merged into it (in
    ///   ascending start-block order), plus the inclusive sub-ranges of the request that
    ///   no entry covers.
    ///
    /// Intersecting entries are merged whole, so the merged value may describe blocks
    /// outside the requested range.
    pub fn calculate_gaps<F>(
        &self,
        key: &K,
        start_block: BlockNumber,
        end_block: BlockNumber,
        create_empty: F,
    ) -> (Option<V>, Vec<(BlockNumber, BlockNumber)>)
    where
        F: FnOnce() -> V,
    {
        if let Some(result) = self.get(key, start_block, end_block) {
            return (Some(result), vec![]);
        }

        // Already sorted by start block, which the sweep below relies on.
        let overlapping = self.find_overlapping(key, start_block, end_block);
        if overlapping.is_empty() {
            return (None, vec![(start_block, end_block)]);
        }

        let mut merged_result = create_empty();
        let mut gaps = Vec::new();
        // First block not yet known to be covered; `None` once coverage reaches
        // `BlockNumber::MAX`, after which no further gap can exist.
        let mut next_uncovered = Some(start_block);

        for ((_, range_start, range_end), result) in &overlapping {
            merged_result.merge(result);

            if let Some(current) = next_uncovered {
                if current < *range_start {
                    // `range_start > current >= 0`, so the subtraction cannot underflow.
                    gaps.push((current, *range_start - 1));
                }
                // `checked_add` avoids overflowing when a cached range ends at the maximum
                // block number.
                next_uncovered = range_end.checked_add(1).map(|next| max(current, next));
            }
        }

        if let Some(current) = next_uncovered {
            if current <= end_block {
                gaps.push((current, end_block));
            }
        }

        (Some(merged_result), gaps)
    }

    /// Returns the number of `(key, range)` entries currently stored.
    pub fn len(&self) -> usize {
        self.cache.len()
    }

    /// Returns `true` when the cache holds no entries.
    pub fn is_empty(&self) -> bool {
        self.cache.is_empty()
    }

    /// Keeps only the entries for which `predicate(key, start_block, end_block)` returns
    /// `true` and drops all others.
    pub fn retain<F>(&mut self, mut predicate: F)
    where
        F: FnMut(&K, BlockNumber, BlockNumber) -> bool,
    {
        self.cache
            .retain(|(key, start, end), _| predicate(key, *start, *end));
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Small mergeable payload used by the tests: merging sums both fields.
    #[derive(Debug, Clone, PartialEq, Default)]
    struct TestValue {
        count: usize,
        total: u64,
    }

    impl TestValue {
        fn new(count: usize, total: u64) -> Self {
            Self { count, total }
        }
    }

    impl Mergeable for TestValue {
        fn merge(&mut self, other: &Self) {
            self.count += other.count;
            self.total += other.total;
        }
    }

    /// Cache flavour exercised by every test below.
    type TestCache = BlockRangeCache<String, TestValue>;

    /// The key shared by most tests.
    fn test_key() -> String {
        String::from("test")
    }

    #[test]
    fn test_cache_empty_get_returns_none() {
        let cache = TestCache::default();
        let lookup = cache.get(&test_key(), 100, 200);
        assert!(lookup.is_none(), "Empty cache should return None");
    }

    #[test]
    fn test_cache_exact_match() {
        let mut cache = TestCache::default();
        let key = test_key();
        let expected = TestValue::new(5, 1000);
        cache.insert(key.clone(), 100, 200, expected.clone());

        let found = cache.get(&key, 100, 200).expect("Should find exact match");
        assert_eq!(found, expected);
    }

    #[test]
    fn test_cache_fully_contained_range() {
        let mut cache = TestCache::default();
        let key = test_key();
        let expected = TestValue::new(5, 1000);
        cache.insert(key.clone(), 50, 250, expected.clone());

        let found = cache
            .get(&key, 100, 200)
            .expect("Should find contained range");
        assert_eq!(found, expected);
    }

    #[test]
    fn test_cache_partial_overlap_returns_none() {
        let mut cache = TestCache::default();
        let key = test_key();
        cache.insert(key.clone(), 100, 200, TestValue::new(5, 1000));

        let lookup = cache.get(&key, 150, 250);
        assert!(
            lookup.is_none(),
            "Partial overlap should return None from get()"
        );
    }

    #[test]
    fn test_insert_with_overlap_merges() {
        let mut cache = TestCache::default();
        let key = test_key();
        cache.insert(key.clone(), 100, 200, TestValue::new(5, 500));
        cache.insert(key.clone(), 150, 250, TestValue::new(3, 800));

        let merged = cache
            .get(&key, 100, 250)
            .expect("Should find merged range");
        assert_eq!(merged.count, 8);
        assert_eq!(merged.total, 1300);
    }

    #[test]
    fn test_calculate_gaps_empty_cache() {
        let cache = TestCache::default();
        let key = test_key();

        let (cached, gaps) = cache.calculate_gaps(&key, 100, 200, || TestValue::new(0, 0));
        assert!(cached.is_none(), "Empty cache should return None result");
        assert_eq!(gaps.len(), 1, "Should have one gap covering entire range");
        assert_eq!(gaps[0], (100, 200));
    }

    #[test]
    fn test_calculate_gaps_fully_cached() {
        let mut cache = TestCache::default();
        let key = test_key();
        cache.insert(key.clone(), 50, 250, TestValue::new(10, 1000));

        let (cached, gaps) = cache.calculate_gaps(&key, 100, 200, || TestValue::new(0, 0));
        assert!(cached.is_some(), "Should return cached result");
        assert_eq!(gaps.len(), 0, "No gaps when fully cached");
    }

    #[test]
    fn test_calculate_gaps_middle_gap() {
        let mut cache = TestCache::default();
        let key = test_key();
        cache.insert(key.clone(), 100, 150, TestValue::new(5, 500));
        cache.insert(key.clone(), 200, 250, TestValue::new(8, 800));

        let (cached, gaps) = cache.calculate_gaps(&key, 100, 250, || TestValue::new(0, 0));
        let merged = cached.expect("Should merge cached data");
        assert_eq!(gaps.len(), 1, "Should have one gap in middle");
        assert_eq!(gaps[0], (151, 199), "Gap should be from 151 to 199");
        assert_eq!(merged.count, 13);
        assert_eq!(merged.total, 1300);
    }

    #[test]
    fn test_len_and_is_empty() {
        let mut cache = TestCache::default();
        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());

        cache.insert("test".to_string(), 100, 200, TestValue::new(1, 100));
        assert_eq!(cache.len(), 1);
        assert!(!cache.is_empty());
    }

    #[test]
    fn test_retain() {
        let mut cache = TestCache::default();
        let kept = "keep".to_string();
        let dropped = "remove".to_string();
        cache.insert(kept.clone(), 100, 200, TestValue::new(1, 100));
        cache.insert(dropped.clone(), 300, 400, TestValue::new(2, 200));

        cache.retain(|key, _start, _end| !key.contains("remove"));

        assert_eq!(cache.len(), 1);
        assert!(cache.get(&kept, 100, 200).is_some());
        assert!(cache.get(&dropped, 300, 400).is_none());
    }
}