use std::collections::hash_map::DefaultHasher;
use std::collections::{HashMap, HashSet};
use std::hash::{Hash, Hasher};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use crate::error::TransportError;
use crate::request::{JsonRpcRequest, JsonRpcResponse};
use crate::transport::RpcTransport;
/// Freshness class assigned to a JSON-RPC call; it determines how long the call's
/// response may be served from the cache.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CacheTier {
    /// Treated as permanent; cached for one hour.
    Immutable,
    /// Expected to change rarely; cached for five minutes.
    SemiStable,
    /// Goes stale quickly; cached for two seconds.
    Volatile,
    /// Never stored; every request goes to the inner transport.
    NeverCache,
}

impl CacheTier {
    /// Returns the built-in time-to-live for this tier.
    ///
    /// `None` means "do not cache" and is returned only for [`CacheTier::NeverCache`].
    pub fn default_ttl(&self) -> Option<Duration> {
        let seconds: u64 = match *self {
            CacheTier::NeverCache => return None,
            CacheTier::Volatile => 2,
            CacheTier::SemiStable => 5 * 60,
            CacheTier::Immutable => 60 * 60,
        };
        Some(Duration::from_secs(seconds))
    }
}
/// Maps a JSON-RPC method name (and, for `eth_getBlockByNumber`, its first
/// parameter) to a [`CacheTier`].
///
/// The private unit field prevents construction with a struct literal outside this
/// module; use [`CacheTierResolver::new`] or [`Default`].
#[derive(Debug, Clone)]
pub struct CacheTierResolver {
    _private: (),
}

impl CacheTierResolver {
    /// Creates a resolver backed by the built-in method table.
    pub fn new() -> Self {
        CacheTierResolver { _private: () }
    }

    /// Classifies one call.
    ///
    /// `eth_getBlockByNumber` is `Immutable` when its first parameter is a concrete
    /// block number (per `is_concrete_block_number`), and `Volatile` when that
    /// parameter is a tag such as `"latest"` or is missing. Every method not listed
    /// below is `NeverCache`. That includes transaction submission, signing,
    /// subscription and filter methods.
    pub fn tier_for(&self, method: &str, params: &[serde_json::Value]) -> CacheTier {
        const IMMUTABLE: &[&str] = &[
            "eth_getTransactionByHash",
            "eth_getTransactionReceipt",
            "eth_getBlockByHash",
        ];
        const SEMI_STABLE: &[&str] = &[
            "eth_chainId",
            "net_version",
            "eth_getCode",
            "net_listening",
            "web3_clientVersion",
            "eth_protocolVersion",
            "eth_accounts",
        ];
        const VOLATILE: &[&str] = &[
            "eth_blockNumber",
            "eth_gasPrice",
            "eth_estimateGas",
            "eth_getBalance",
            "eth_getTransactionCount",
            "eth_call",
            "eth_feeHistory",
            "eth_maxPriorityFeePerGas",
            "eth_getStorageAt",
        ];

        if method == "eth_getBlockByNumber" {
            let names_concrete_block =
                matches!(params.first(), Some(block) if is_concrete_block_number(block));
            return if names_concrete_block {
                CacheTier::Immutable
            } else {
                CacheTier::Volatile
            };
        }

        let listed_in = |table: &[&str]| table.iter().any(|&known| known == method);
        if listed_in(IMMUTABLE) {
            CacheTier::Immutable
        } else if listed_in(SEMI_STABLE) {
            CacheTier::SemiStable
        } else if listed_in(VOLATILE) {
            CacheTier::Volatile
        } else {
            CacheTier::NeverCache
        }
    }
}

impl Default for CacheTierResolver {
    fn default() -> Self {
        CacheTierResolver::new()
    }
}
/// Caching policy for a `CacheTransport`.
///
/// There are two modes. When `tier_resolver` is `None` ("legacy" mode), only
/// `cacheable_methods` are cached, and all of them use `default_ttl`. When it is
/// `Some`, the resolver picks a tier per call, and that tier's own TTL applies.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// TTL for every entry in legacy mode. In tiered mode each tier's own TTL is used.
    pub default_ttl: Duration,
    /// Maximum number of entries kept; the oldest entry is evicted to make room.
    pub max_entries: usize,
    /// Methods that are cached in legacy mode; ignored when `tier_resolver` is set.
    pub cacheable_methods: HashSet<String>,
    /// Enables tiered mode when present.
    pub tier_resolver: Option<CacheTierResolver>,
}

impl Default for CacheConfig {
    /// Legacy mode: a 60 s TTL, up to 1024 entries, and a small set of cacheable methods.
    fn default() -> Self {
        const DEFAULT_CACHEABLE: [&str; 4] = [
            "eth_chainId",
            "eth_getBlockByNumber",
            "eth_getCode",
            "net_version",
        ];
        CacheConfig {
            default_ttl: Duration::from_secs(60),
            max_entries: 1024,
            cacheable_methods: DEFAULT_CACHEABLE
                .iter()
                .map(|&name| name.to_owned())
                .collect(),
            tier_resolver: None,
        }
    }
}
/// One cached JSON-RPC response plus the metadata needed to expire or invalidate it.
struct CacheEntry {
/// Method name of the cached request; matched by `CacheTransport::invalidate_method`.
method: String,
/// Response handed back (cloned) on every cache hit.
response: JsonRpcResponse,
/// When the entry was stored; `inserted_at.elapsed() < ttl` means still live.
inserted_at: Instant,
// Written when the entry is created but never read in this file, hence the
// dead_code allowance. NOTE(review): keep or drop depending on whether anything
// outside this file is meant to inspect it.
#[allow(dead_code)]
tier: CacheTier,
/// Block number the response is tied to (from `extract_block_ref`), if any.
/// `invalidate_for_reorg` drops entries whose block is at or above the reorg point;
/// entries with `None` are never dropped by it.
block_ref: Option<u64>,
/// How long after `inserted_at` the entry may still be served.
ttl: Duration,
}
/// Cache effectiveness counters, as returned by `CacheTransport::stats`.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
/// Cacheable requests answered from the cache.
pub hits: u64,
/// Cacheable requests that were forwarded to the inner transport.
pub misses: u64,
/// Number of entries held when the snapshot was taken. Filled in by
/// `CacheTransport::stats`; the copy stored inside the cache never updates it.
pub size: usize,
}
/// Mutable cache state guarded by `CacheTransport`'s mutex.
struct CacheInner {
/// Entries keyed by `cache_key(method, params)`, a 64-bit hash of the request.
/// NOTE(review): only the hash is kept, so two distinct requests whose hashes collide
/// would share one entry.
entries: HashMap<u64, CacheEntry>,
/// Running hit/miss counters; `size` is computed on demand by `CacheTransport::stats`.
stats: CacheStats,
}
/// Wraps an [`RpcTransport`] and keeps its successful responses in an in-memory cache.
///
/// `config` decides which requests are cached and for how long. It is either a fixed
/// method list with a single TTL, or per-call tiers chosen by a `CacheTierResolver`.
pub struct CacheTransport {
/// Transport that uncacheable requests and cache misses are forwarded to.
inner: Arc<dyn RpcTransport>,
/// Entries and counters behind a std `Mutex`; the methods of this type never hold
/// the guard across an `.await`.
cache: Mutex<CacheInner>,
/// Caching policy supplied at construction.
config: CacheConfig,
}
impl CacheTransport {
    /// Wraps `inner` in a response cache governed by `config`.
    ///
    /// The cache starts empty with zeroed hit/miss counters.
    pub fn new(inner: Arc<dyn RpcTransport>, config: CacheConfig) -> Self {
        Self {
            inner,
            cache: Mutex::new(CacheInner {
                entries: HashMap::new(),
                stats: CacheStats::default(),
            }),
            config,
        }
    }

    /// Sends `req`, answering from the cache when a live entry exists.
    ///
    /// * Some requests go straight to the inner transport and leave the hit/miss
    ///   counters untouched. These are uncacheable requests (see
    ///   `resolve_cacheability`), plus every request when `max_entries` is 0.
    /// * A live cached entry is returned as-is and counted as a hit.
    /// * Otherwise the request is forwarded and counted as a miss. The response is
    ///   stored only when it is not an error and carries a non-null `result`. A `null`
    ///   result means "not found yet" (for example, an Ethereum node returns it for a
    ///   receipt that is not mined yet). Caching it would hide the real value for the
    ///   whole tier TTL.
    ///
    /// The cache lock is released before awaiting the inner transport. Concurrent misses
    /// for the same key may therefore each reach upstream; the last response stored wins.
    ///
    /// # Errors
    ///
    /// Returns whatever [`TransportError`] the inner transport produces. Errors are never
    /// cached.
    ///
    /// # Panics
    ///
    /// Panics if the cache mutex is poisoned.
    pub async fn send(&self, req: JsonRpcRequest) -> Result<JsonRpcResponse, TransportError> {
        let (is_cacheable, tier, ttl) = self.resolve_cacheability(&req);
        // A zero-capacity cache can never hold anything. Bypass it rather than running
        // the lookup/insert path, whose eviction loop used to spin forever on an
        // empty map when `max_entries == 0`.
        if !is_cacheable || self.config.max_entries == 0 {
            return self.inner.send(req).await;
        }

        let key = cache_key(&req.method, &req.params);
        {
            let mut inner = self.cache.lock().unwrap();
            self.evict_expired(&mut inner);
            let cached = inner.entries.get(&key).and_then(|entry| {
                if entry.inserted_at.elapsed() < entry.ttl {
                    Some(entry.response.clone())
                } else {
                    None
                }
            });
            if let Some(response) = cached {
                inner.stats.hits += 1;
                tracing::debug!(method = %req.method, "cache hit");
                return Ok(response);
            }
            inner.entries.remove(&key);
            inner.stats.misses += 1;
        } // Guard dropped here: a std Mutex guard must not live across `.await`.

        // Capture what the insert needs now, so that `req` can be moved into the inner
        // transport instead of being cloned wholesale.
        let method = req.method.clone();
        let block_ref = extract_block_ref(&req.method, &req.params);
        let response = self.inner.send(req).await?;

        if Self::is_storable(&response) {
            let mut inner = self.cache.lock().unwrap();
            // Overwriting an existing key does not grow the map, so only evict when
            // a new key is being added. The emptiness check guarantees the loop ends.
            if !inner.entries.contains_key(&key) {
                while !inner.entries.is_empty()
                    && inner.entries.len() >= self.config.max_entries
                {
                    self.evict_oldest(&mut inner);
                }
            }
            tracing::debug!(method = %method, ?tier, "cached response");
            inner.entries.insert(
                key,
                CacheEntry {
                    method,
                    response: response.clone(),
                    inserted_at: Instant::now(),
                    tier,
                    block_ref,
                    ttl,
                },
            );
        }
        Ok(response)
    }

    /// Removes every cached entry. Hit/miss counters are left unchanged.
    pub fn invalidate(&self) {
        let mut inner = self.cache.lock().unwrap();
        inner.entries.clear();
        tracing::info!("cache invalidated (all entries)");
    }

    /// Removes every cached entry whose request used `method`, whatever its params.
    pub fn invalidate_method(&self, method: &str) {
        let mut inner = self.cache.lock().unwrap();
        let before = inner.entries.len();
        inner.entries.retain(|_, entry| entry.method != method);
        let removed = before - inner.entries.len();
        tracing::info!(method, removed, "cache invalidated for method");
    }

    /// Drops every entry tied to a block number `>= from_block`.
    ///
    /// Entries without a block reference are kept. That covers `eth_chainId`,
    /// tag-based queries, and hash-keyed lookups such as receipts or blocks by hash.
    pub fn invalidate_for_reorg(&self, from_block: u64) {
        let mut inner = self.cache.lock().unwrap();
        let before = inner.entries.len();
        inner.entries.retain(|_, entry| match entry.block_ref {
            Some(block) => block < from_block,
            None => true,
        });
        let removed = before - inner.entries.len();
        tracing::info!(from_block, removed, "cache invalidated for reorg");
    }

    /// Returns a snapshot of the hit/miss counters and the current entry count.
    ///
    /// `size` may include expired entries that have not been swept yet. Sweeping only
    /// happens during a cacheable `send`.
    pub fn stats(&self) -> CacheStats {
        let inner = self.cache.lock().unwrap();
        CacheStats {
            hits: inner.stats.hits,
            misses: inner.stats.misses,
            size: inner.entries.len(),
        }
    }

    /// Decides whether `req` may be cached, and with which tier and TTL.
    ///
    /// In tiered mode (`tier_resolver` set), the resolver's tier decides and its own TTL
    /// is used; `NeverCache` disables caching. In legacy mode, only `cacheable_methods`
    /// are cached, all with `default_ttl`. The `SemiStable` tier returned there is only
    /// recorded on the entry and logged.
    fn resolve_cacheability(&self, req: &JsonRpcRequest) -> (bool, CacheTier, Duration) {
        if let Some(ref resolver) = self.config.tier_resolver {
            let tier = resolver.tier_for(&req.method, &req.params);
            match tier {
                CacheTier::NeverCache => (false, tier, Duration::ZERO),
                _ => {
                    let ttl = tier.default_ttl().unwrap_or(self.config.default_ttl);
                    (true, tier, ttl)
                }
            }
        } else {
            let is_cacheable = self.config.cacheable_methods.contains(&req.method);
            (is_cacheable, CacheTier::SemiStable, self.config.default_ttl)
        }
    }

    /// Whether a freshly fetched response should be stored: it must be a success with
    /// a non-null `result`.
    fn is_storable(response: &JsonRpcResponse) -> bool {
        response.is_ok() && matches!(&response.result, Some(value) if !value.is_null())
    }

    /// Removes every entry whose TTL has elapsed.
    fn evict_expired(&self, inner: &mut CacheInner) {
        inner
            .entries
            .retain(|_, entry| entry.inserted_at.elapsed() < entry.ttl);
    }

    /// Removes the entry with the earliest `inserted_at`, if there is one.
    fn evict_oldest(&self, inner: &mut CacheInner) {
        let oldest_key = inner
            .entries
            .iter()
            .min_by_key(|(_, entry)| entry.inserted_at)
            .map(|(key, _)| *key);
        if let Some(key) = oldest_key {
            inner.entries.remove(&key);
        }
    }
}
/// Derives the cache key for a request: a 64-bit hash of the method name followed by
/// the JSON serialization of `params`.
///
/// `DefaultHasher::new()` is deterministic within a process, so the same request
/// always maps to the same key. If serialization fails, an empty string is hashed in
/// place of the params.
fn cache_key(method: &str, params: &[serde_json::Value]) -> u64 {
    let serialized_params = serde_json::to_string(params).unwrap_or_default();
    let mut state = DefaultHasher::new();
    method.hash(&mut state);
    serialized_params.hash(&mut state);
    state.finish()
}
fn is_concrete_block_number(value: &serde_json::Value) -> bool {
match value.as_str() {
Some(s) => {
let tags = ["latest", "pending", "earliest", "safe", "finalized"];
if tags.contains(&s) {
return false;
}
s.starts_with("0x") || s.starts_with("0X")
}
None => {
value.is_number()
}
}
}
fn extract_block_ref(method: &str, params: &[serde_json::Value]) -> Option<u64> {
match method {
"eth_getBlockByNumber" => params.first().and_then(parse_hex_block),
"eth_getTransactionByBlockNumberAndIndex" => params.first().and_then(parse_hex_block),
_ => None,
}
}
/// Decodes a `0x`/`0X`-prefixed hex block number that fits in a `u64`.
///
/// Returns `None` in these cases:
///
/// * the value is not a string;
/// * the prefix is missing;
/// * there are no digits after the prefix;
/// * any character after the prefix is not a hex digit;
/// * the number overflows `u64`.
fn parse_hex_block(value: &serde_json::Value) -> Option<u64> {
    let s = value.as_str()?;
    let digits = s.strip_prefix("0x").or_else(|| s.strip_prefix("0X"))?;
    // `u64::from_str_radix` also accepts a leading `+` (so "0x+64" would parse as
    // 100); insist on bare hex digits. An empty `digits` still fails in from_str_radix.
    if !digits.bytes().all(|b| b.is_ascii_hexdigit()) {
        return None;
    }
    u64::from_str_radix(digits, 16).ok()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::request::{JsonRpcRequest, JsonRpcResponse, RpcId};
    use async_trait::async_trait;
    use std::sync::atomic::{AtomicU64, Ordering};

    /// Mock transport that counts calls and always answers with the result `"0x1"`.
    struct CountingTransport {
        call_count: AtomicU64,
    }

    impl CountingTransport {
        fn new() -> Self {
            Self {
                call_count: AtomicU64::new(0),
            }
        }

        /// Number of requests that reached this transport.
        fn calls(&self) -> u64 {
            self.call_count.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl RpcTransport for CountingTransport {
        async fn send(&self, _req: JsonRpcRequest) -> Result<JsonRpcResponse, TransportError> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            Ok(JsonRpcResponse {
                jsonrpc: "2.0".into(),
                id: RpcId::Number(1),
                result: Some(serde_json::Value::String("0x1".into())),
                error: None,
            })
        }

        fn url(&self) -> &str {
            "mock://counting"
        }
    }

    /// Legacy (method-list) mode that caches only `eth_chainId`.
    fn default_config() -> CacheConfig {
        CacheConfig {
            default_ttl: Duration::from_secs(60),
            max_entries: 128,
            cacheable_methods: ["eth_chainId"].iter().map(|s| s.to_string()).collect(),
            tier_resolver: None,
        }
    }

    /// Tiered mode; the method list is empty because the resolver decides.
    fn tiered_config() -> CacheConfig {
        CacheConfig {
            default_ttl: Duration::from_secs(60),
            max_entries: 128,
            cacheable_methods: HashSet::new(),
            tier_resolver: Some(CacheTierResolver::new()),
        }
    }

    fn make_req(method: &str) -> JsonRpcRequest {
        JsonRpcRequest::new(1, method, vec![])
    }

    fn make_req_with_params(method: &str, params: Vec<serde_json::Value>) -> JsonRpcRequest {
        JsonRpcRequest::new(1, method, params)
    }

    #[tokio::test]
    async fn cache_hit_returns_same_response() {
        let transport = Arc::new(CountingTransport::new());
        let cache = CacheTransport::new(transport.clone(), default_config());
        let req = make_req("eth_chainId");
        let r1 = cache.send(req.clone()).await.unwrap();
        let r2 = cache.send(req).await.unwrap();
        assert_eq!(r1.result, r2.result);
        assert_eq!(transport.calls(), 1);
    }

    #[tokio::test]
    async fn cache_miss_delegates_to_inner() {
        let transport = Arc::new(CountingTransport::new());
        let cache = CacheTransport::new(transport.clone(), default_config());
        let _r = cache.send(make_req("eth_chainId")).await.unwrap();
        assert_eq!(transport.calls(), 1);
        let stats = cache.stats();
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.size, 1);
    }

    #[tokio::test]
    async fn ttl_expiry_works() {
        let transport = Arc::new(CountingTransport::new());
        // Short legacy TTL so the test can outwait it.
        let config = CacheConfig {
            default_ttl: Duration::from_millis(50),
            max_entries: 128,
            cacheable_methods: ["eth_chainId"].iter().map(|s| s.to_string()).collect(),
            tier_resolver: None,
        };
        let cache = CacheTransport::new(transport.clone(), config);
        let req = make_req("eth_chainId");
        cache.send(req.clone()).await.unwrap();
        assert_eq!(transport.calls(), 1);
        tokio::time::sleep(Duration::from_millis(100)).await;
        cache.send(req).await.unwrap();
        assert_eq!(transport.calls(), 2);
    }

    #[tokio::test]
    async fn non_cacheable_methods_bypass_cache() {
        let transport = Arc::new(CountingTransport::new());
        let cache = CacheTransport::new(transport.clone(), default_config());
        let req = make_req("eth_blockNumber");
        cache.send(req.clone()).await.unwrap();
        cache.send(req).await.unwrap();
        assert_eq!(transport.calls(), 2);
        assert_eq!(cache.stats().size, 0);
    }

    #[tokio::test]
    async fn invalidate_clears_cache() {
        let transport = Arc::new(CountingTransport::new());
        let cache = CacheTransport::new(transport.clone(), default_config());
        cache.send(make_req("eth_chainId")).await.unwrap();
        assert_eq!(cache.stats().size, 1);
        cache.invalidate();
        assert_eq!(cache.stats().size, 0);
        cache.send(make_req("eth_chainId")).await.unwrap();
        assert_eq!(transport.calls(), 2);
    }

    #[tokio::test]
    async fn max_entries_evicts_oldest() {
        let transport = Arc::new(CountingTransport::new());
        let config = CacheConfig {
            default_ttl: Duration::from_secs(60),
            max_entries: 2,
            cacheable_methods: ["eth_chainId", "eth_getCode"]
                .iter()
                .map(|s| s.to_string())
                .collect(),
            tier_resolver: None,
        };
        let cache = CacheTransport::new(transport.clone(), config);
        cache
            .send(JsonRpcRequest::new(
                1,
                "eth_chainId",
                vec![serde_json::Value::String("a".into())],
            ))
            .await
            .unwrap();
        cache
            .send(JsonRpcRequest::new(
                2,
                "eth_chainId",
                vec![serde_json::Value::String("b".into())],
            ))
            .await
            .unwrap();
        assert_eq!(cache.stats().size, 2);
        // A third distinct key must push one entry out to stay within capacity.
        cache
            .send(JsonRpcRequest::new(
                3,
                "eth_getCode",
                vec![serde_json::Value::String("c".into())],
            ))
            .await
            .unwrap();
        assert_eq!(cache.stats().size, 2);
    }

    #[tokio::test]
    async fn invalidate_method_is_targeted() {
        let transport = Arc::new(CountingTransport::new());
        let config = CacheConfig {
            default_ttl: Duration::from_secs(60),
            max_entries: 128,
            cacheable_methods: ["eth_chainId", "eth_getCode"]
                .iter()
                .map(|s| s.to_string())
                .collect(),
            tier_resolver: None,
        };
        let cache = CacheTransport::new(transport.clone(), config);
        cache.send(make_req("eth_chainId")).await.unwrap();
        cache.send(make_req("eth_getCode")).await.unwrap();
        assert_eq!(cache.stats().size, 2);
        cache.invalidate_method("eth_chainId");
        assert_eq!(cache.stats().size, 1);
        // Only the invalidated method has to be fetched again.
        cache.send(make_req("eth_chainId")).await.unwrap();
        assert_eq!(transport.calls(), 3);
    }

    #[test]
    fn cache_key_deterministic() {
        let k1 = cache_key("eth_chainId", &[]);
        let k2 = cache_key("eth_chainId", &[]);
        assert_eq!(k1, k2);
        let k3 = cache_key("eth_blockNumber", &[]);
        assert_ne!(k1, k3);
    }

    #[test]
    fn cache_key_differs_by_params() {
        let k1 = cache_key("eth_getCode", &[serde_json::Value::String("0xabc".into())]);
        let k2 = cache_key("eth_getCode", &[serde_json::Value::String("0xdef".into())]);
        assert_ne!(k1, k2);
    }

    #[test]
    fn tier_default_ttls() {
        assert_eq!(
            CacheTier::Immutable.default_ttl(),
            Some(Duration::from_secs(3600))
        );
        assert_eq!(
            CacheTier::SemiStable.default_ttl(),
            Some(Duration::from_secs(300))
        );
        assert_eq!(
            CacheTier::Volatile.default_ttl(),
            Some(Duration::from_secs(2))
        );
        assert_eq!(CacheTier::NeverCache.default_ttl(), None);
    }

    #[test]
    fn resolver_classifies_methods() {
        let resolver = CacheTierResolver::new();
        assert_eq!(
            resolver.tier_for("eth_getTransactionReceipt", &[]),
            CacheTier::Immutable
        );
        assert_eq!(
            resolver.tier_for("eth_getTransactionByHash", &[]),
            CacheTier::Immutable
        );
        assert_eq!(resolver.tier_for("eth_chainId", &[]), CacheTier::SemiStable);
        assert_eq!(resolver.tier_for("net_version", &[]), CacheTier::SemiStable);
        assert_eq!(resolver.tier_for("eth_getCode", &[]), CacheTier::SemiStable);
        assert_eq!(
            resolver.tier_for("eth_blockNumber", &[]),
            CacheTier::Volatile
        );
        assert_eq!(resolver.tier_for("eth_gasPrice", &[]), CacheTier::Volatile);
        assert_eq!(
            resolver.tier_for("eth_sendRawTransaction", &[]),
            CacheTier::NeverCache
        );
        assert_eq!(
            resolver.tier_for("eth_subscribe", &[]),
            CacheTier::NeverCache
        );
    }

    #[tokio::test]
    async fn tier_immutable_long_ttl() {
        let transport = Arc::new(CountingTransport::new());
        // The 50 ms config default must be ignored in favour of the Immutable tier TTL.
        let config = CacheConfig {
            default_ttl: Duration::from_millis(50),
            max_entries: 128,
            cacheable_methods: HashSet::new(),
            tier_resolver: Some(CacheTierResolver::new()),
        };
        let cache = CacheTransport::new(transport.clone(), config);
        let req = make_req_with_params(
            "eth_getTransactionReceipt",
            vec![serde_json::Value::String("0xabc123def456".into())],
        );
        cache.send(req.clone()).await.unwrap();
        assert_eq!(transport.calls(), 1);
        tokio::time::sleep(Duration::from_millis(100)).await;
        cache.send(req).await.unwrap();
        // Still served from the cache after the config default would have expired.
        assert_eq!(transport.calls(), 1);
    }

    #[tokio::test]
    async fn tier_volatile_short_ttl() {
        let transport = Arc::new(CountingTransport::new());
        // The 60 s config default must be ignored in favour of the 2 s Volatile TTL.
        let config = CacheConfig {
            default_ttl: Duration::from_secs(60),
            max_entries: 128,
            cacheable_methods: HashSet::new(),
            tier_resolver: Some(CacheTierResolver::new()),
        };
        let cache = CacheTransport::new(transport.clone(), config);
        let req = make_req("eth_gasPrice");
        cache.send(req.clone()).await.unwrap();
        assert_eq!(transport.calls(), 1);
        tokio::time::sleep(Duration::from_millis(2100)).await;
        cache.send(req).await.unwrap();
        assert_eq!(transport.calls(), 2);
    }

    #[tokio::test]
    async fn tier_never_cache_bypasses() {
        let transport = Arc::new(CountingTransport::new());
        let cache = CacheTransport::new(transport.clone(), tiered_config());
        let req = make_req("eth_sendRawTransaction");
        cache.send(req.clone()).await.unwrap();
        cache.send(req).await.unwrap();
        assert_eq!(transport.calls(), 2);
        assert_eq!(cache.stats().size, 0);
    }

    #[tokio::test]
    async fn reorg_invalidation_removes_affected() {
        let transport = Arc::new(CountingTransport::new());
        let cache = CacheTransport::new(transport.clone(), tiered_config());
        for block in [100u64, 200, 300] {
            let req = make_req_with_params(
                "eth_getBlockByNumber",
                vec![
                    serde_json::Value::String(format!("0x{:x}", block)),
                    serde_json::Value::Bool(true),
                ],
            );
            cache.send(req).await.unwrap();
        }
        assert_eq!(cache.stats().size, 3);
        // Blocks 200 and 300 are at or above the reorg point; only 100 survives.
        cache.invalidate_for_reorg(200);
        assert_eq!(cache.stats().size, 1);
        let req200 = make_req_with_params(
            "eth_getBlockByNumber",
            vec![
                serde_json::Value::String("0xc8".into()),
                serde_json::Value::Bool(true),
            ],
        );
        cache.send(req200).await.unwrap();
        assert_eq!(transport.calls(), 4);
    }

    // Pure classification check: nothing is awaited, so no async runtime is needed.
    #[test]
    fn block_param_latest_is_volatile() {
        let resolver = CacheTierResolver::new();
        let tier_latest = resolver.tier_for(
            "eth_getBlockByNumber",
            &[
                serde_json::Value::String("latest".into()),
                serde_json::Value::Bool(true),
            ],
        );
        assert_eq!(tier_latest, CacheTier::Volatile);
        let tier_pending = resolver.tier_for(
            "eth_getBlockByNumber",
            &[
                serde_json::Value::String("pending".into()),
                serde_json::Value::Bool(true),
            ],
        );
        assert_eq!(tier_pending, CacheTier::Volatile);
        let tier_concrete = resolver.tier_for(
            "eth_getBlockByNumber",
            &[
                serde_json::Value::String("0x10d4f".into()),
                serde_json::Value::Bool(true),
            ],
        );
        assert_eq!(tier_concrete, CacheTier::Immutable);
    }

    #[test]
    fn is_concrete_block_number_checks() {
        assert!(!is_concrete_block_number(&serde_json::Value::String(
            "latest".into()
        )));
        assert!(!is_concrete_block_number(&serde_json::Value::String(
            "pending".into()
        )));
        assert!(!is_concrete_block_number(&serde_json::Value::String(
            "earliest".into()
        )));
        assert!(!is_concrete_block_number(&serde_json::Value::String(
            "safe".into()
        )));
        assert!(!is_concrete_block_number(&serde_json::Value::String(
            "finalized".into()
        )));
        assert!(is_concrete_block_number(&serde_json::Value::String(
            "0x10d4f".into()
        )));
        assert!(is_concrete_block_number(&serde_json::Value::String(
            "0X1A".into()
        )));
        assert!(is_concrete_block_number(&serde_json::json!(42)));
    }

    #[test]
    fn parse_hex_block_works() {
        assert_eq!(
            parse_hex_block(&serde_json::Value::String("0x64".into())),
            Some(100)
        );
        assert_eq!(
            parse_hex_block(&serde_json::Value::String("0xc8".into())),
            Some(200)
        );
        assert_eq!(
            parse_hex_block(&serde_json::Value::String("latest".into())),
            None
        );
        assert_eq!(parse_hex_block(&serde_json::json!(42)), None);
    }

    #[test]
    fn extract_block_ref_for_get_block() {
        assert_eq!(
            extract_block_ref(
                "eth_getBlockByNumber",
                &[serde_json::Value::String("0x64".into())]
            ),
            Some(100)
        );
        assert_eq!(
            extract_block_ref(
                "eth_getBlockByNumber",
                &[serde_json::Value::String("latest".into())]
            ),
            None
        );
        assert_eq!(extract_block_ref("eth_chainId", &[]), None);
    }

    #[tokio::test]
    async fn tiered_mode_caches_semi_stable() {
        let transport = Arc::new(CountingTransport::new());
        let cache = CacheTransport::new(transport.clone(), tiered_config());
        let req = make_req("eth_chainId");
        cache.send(req.clone()).await.unwrap();
        cache.send(req).await.unwrap();
        assert_eq!(transport.calls(), 1);
        assert_eq!(cache.stats().hits, 1);
        assert_eq!(cache.stats().misses, 1);
    }

    #[tokio::test]
    async fn reorg_keeps_unrelated_entries() {
        let transport = Arc::new(CountingTransport::new());
        let cache = CacheTransport::new(transport.clone(), tiered_config());
        cache.send(make_req("eth_chainId")).await.unwrap();
        let block_req = make_req_with_params(
            "eth_getBlockByNumber",
            vec![
                serde_json::Value::String("0x64".into()),
                serde_json::Value::Bool(true),
            ],
        );
        cache.send(block_req).await.unwrap();
        assert_eq!(cache.stats().size, 2);
        // Block 100 is above the reorg point; eth_chainId has no block ref and stays.
        cache.invalidate_for_reorg(50);
        assert_eq!(cache.stats().size, 1);
    }
}