use alloy::primitives::{Address, Bytes, B256};
use moka::future::Cache;
use std::time::Duration;
/// Cached snapshot of a policy, stored in [`PolicyContractCache`] under
/// `(chain_id, policy_id)`.
///
/// Plain value type: it can be cloned and compared for equality.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct PolicySnapshot {
    /// Address of the policy contract.
    pub policy_address: Address,
    /// Identifier of the policy.
    pub policy_id: B256,
    /// Raw policy parameter bytes.
    pub policy_params: Bytes,
    /// Expiry value for this policy.
    /// NOTE(review): units (seconds vs. blocks) are not visible here — confirm with the producer.
    pub expire_after: u32,
}
/// Sizing and expiry settings for [`PolicyContractCache`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct PolicyContractCacheConfig {
    /// Maximum number of policy snapshots held.
    ///
    /// The snapshot cache has no time-to-live; entries leave it only through
    /// capacity eviction or explicit invalidation.
    pub snapshot_max_capacity: u64,
    /// Time-to-live, in seconds, for the TTL caches.
    ///
    /// Despite its name, this value applies to all three TTL caches: policy
    /// versions, policy-data versions and privacy-requirement flags.
    pub version_ttl_secs: u64,
    /// Maximum number of entries in *each* of the three TTL caches.
    pub version_max_capacity: u64,
}

impl Default for PolicyContractCacheConfig {
    /// Returns the defaults:
    /// - 1024 snapshots;
    /// - a 600-second TTL;
    /// - 512 entries per TTL cache.
    fn default() -> Self {
        Self {
            snapshot_max_capacity: 1024,
            version_ttl_secs: 600,
            version_max_capacity: 512,
        }
    }
}
/// In-memory caches for policy-contract lookups.
///
/// Every key starts with a `u64` chain id, so the same id or address on
/// different chains is stored separately. The snapshot cache is bounded only by
/// capacity; the other three caches are built with a time-to-live
/// (see `PolicyContractCache::new`).
pub struct PolicyContractCache {
    /// Policy snapshots keyed by `(chain_id, policy_id)`; no TTL.
    policy_snapshot: Cache<(u64, B256), PolicySnapshot>,
    /// Policy contract version strings keyed by `(chain_id, policy_address)`.
    policy_version: Cache<(u64, Address), String>,
    /// Policy-data contract version strings keyed by `(chain_id, policy_data_address)`.
    policy_data_version: Cache<(u64, Address), String>,
    /// Privacy-requirement flags keyed by `(chain_id, policy_client)`.
    privacy_requires: Cache<(u64, Address), bool>,
}
/// Debug output lists only the entry count of each inner cache; cached values
/// are never printed.
impl std::fmt::Debug for PolicyContractCache {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let snapshots = self.policy_snapshot.entry_count();
        let versions = self.policy_version.entry_count();
        let data_versions = self.policy_data_version.entry_count();
        let privacy = self.privacy_requires.entry_count();

        let mut out = f.debug_struct("PolicyContractCache");
        out.field("snapshot_entry_count", &snapshots);
        out.field("policy_version_entry_count", &versions);
        out.field("policy_data_version_entry_count", &data_versions);
        out.field("privacy_requires_entry_count", &privacy);
        out.finish()
    }
}
impl Default for PolicyContractCache {
fn default() -> Self {
Self::new(PolicyContractCacheConfig::default())
}
}
impl PolicyContractCache {
pub fn new(config: PolicyContractCacheConfig) -> Self {
Self {
policy_snapshot: Cache::builder().max_capacity(config.snapshot_max_capacity).build(),
policy_version: Cache::builder()
.time_to_live(Duration::from_secs(config.version_ttl_secs))
.max_capacity(config.version_max_capacity)
.build(),
policy_data_version: Cache::builder()
.time_to_live(Duration::from_secs(config.version_ttl_secs))
.max_capacity(config.version_max_capacity)
.build(),
privacy_requires: Cache::builder()
.time_to_live(Duration::from_secs(config.version_ttl_secs))
.max_capacity(config.version_max_capacity)
.build(),
}
}
pub async fn get_snapshot(&self, chain_id: u64, policy_id: B256) -> Option<PolicySnapshot> {
self.policy_snapshot.get(&(chain_id, policy_id)).await
}
pub async fn insert_snapshot(&self, chain_id: u64, policy_id: B256, snapshot: PolicySnapshot) {
self.policy_snapshot.insert((chain_id, policy_id), snapshot).await;
}
pub async fn get_or_try_insert_snapshot<F, E>(
&self,
chain_id: u64,
policy_id: B256,
init: F,
) -> Result<PolicySnapshot, std::sync::Arc<E>>
where
F: std::future::Future<Output = Result<PolicySnapshot, E>>,
E: Send + Sync + 'static,
{
self.policy_snapshot.try_get_with((chain_id, policy_id), init).await
}
pub async fn get_policy_version(&self, chain_id: u64, policy_address: Address) -> Option<String> {
self.policy_version.get(&(chain_id, policy_address)).await
}
pub async fn insert_policy_version(&self, chain_id: u64, policy_address: Address, version: String) {
self.policy_version.insert((chain_id, policy_address), version).await;
}
pub async fn get_or_try_insert_policy_version<F, E>(
&self,
chain_id: u64,
policy_address: Address,
init: F,
) -> Result<String, std::sync::Arc<E>>
where
F: std::future::Future<Output = Result<String, E>>,
E: Send + Sync + 'static,
{
self.policy_version.try_get_with((chain_id, policy_address), init).await
}
pub async fn get_policy_data_version(&self, chain_id: u64, policy_data_address: Address) -> Option<String> {
self.policy_data_version.get(&(chain_id, policy_data_address)).await
}
pub async fn insert_policy_data_version(&self, chain_id: u64, policy_data_address: Address, version: String) {
self.policy_data_version
.insert((chain_id, policy_data_address), version)
.await;
}
pub async fn get_or_try_insert_policy_data_version<F, E>(
&self,
chain_id: u64,
policy_data_address: Address,
init: F,
) -> Result<String, std::sync::Arc<E>>
where
F: std::future::Future<Output = Result<String, E>>,
E: Send + Sync + 'static,
{
self.policy_data_version
.try_get_with((chain_id, policy_data_address), init)
.await
}
pub async fn get_privacy_requires(&self, chain_id: u64, policy_client: Address) -> Option<bool> {
self.privacy_requires.get(&(chain_id, policy_client)).await
}
pub async fn insert_privacy_requires(&self, chain_id: u64, policy_client: Address, requires: bool) {
self.privacy_requires.insert((chain_id, policy_client), requires).await;
}
pub async fn get_or_try_insert_privacy_requires<F, E>(
&self,
chain_id: u64,
policy_client: Address,
init: F,
) -> Result<bool, std::sync::Arc<E>>
where
F: std::future::Future<Output = Result<bool, E>>,
E: Send + Sync + 'static,
{
self.privacy_requires
.try_get_with((chain_id, policy_client), init)
.await
}
pub async fn invalidate_snapshot(&self, chain_id: u64, policy_id: B256) {
self.policy_snapshot.invalidate(&(chain_id, policy_id)).await;
}
pub async fn invalidate_policy_version(&self, chain_id: u64, policy_address: Address) {
self.policy_version.invalidate(&(chain_id, policy_address)).await;
}
pub async fn invalidate_policy_data_version(&self, chain_id: u64, policy_data_address: Address) {
self.policy_data_version
.invalidate(&(chain_id, policy_data_address))
.await;
}
pub async fn invalidate_privacy_requires(&self, chain_id: u64, policy_client: Address) {
self.privacy_requires.invalidate(&(chain_id, policy_client)).await;
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use alloy::primitives::{Address, Bytes, B256};
    use std::time::Duration;

    /// Small capacities and a 1-second TTL so that expiry can be observed quickly.
    fn test_config() -> PolicyContractCacheConfig {
        PolicyContractCacheConfig {
            snapshot_max_capacity: 10,
            version_ttl_secs: 1,
            version_max_capacity: 10,
        }
    }

    fn test_snapshot() -> PolicySnapshot {
        PolicySnapshot {
            policy_address: Address::repeat_byte(0xAA),
            policy_id: B256::repeat_byte(0xBB),
            policy_params: Bytes::from_static(b"test_params"),
            expire_after: 100,
        }
    }

    #[tokio::test]
    async fn test_snapshot_cache_miss_then_hit() {
        let cache = PolicyContractCache::new(test_config());
        let chain_id = 1u64;
        let policy_id = B256::repeat_byte(0x01);
        assert!(cache.get_snapshot(chain_id, policy_id).await.is_none());
        let snapshot = test_snapshot();
        cache.insert_snapshot(chain_id, policy_id, snapshot.clone()).await;
        let cached = cache.get_snapshot(chain_id, policy_id).await.unwrap();
        assert_eq!(cached.policy_address, snapshot.policy_address);
        assert_eq!(cached.policy_id, snapshot.policy_id);
        assert_eq!(cached.expire_after, snapshot.expire_after);
    }

    #[tokio::test]
    async fn test_snapshot_different_chain_ids_are_separate() {
        let cache = PolicyContractCache::new(test_config());
        let policy_id = B256::repeat_byte(0x01);
        cache.insert_snapshot(1, policy_id, test_snapshot()).await;
        assert!(cache.get_snapshot(1, policy_id).await.is_some());
        assert!(cache.get_snapshot(2, policy_id).await.is_none());
    }

    #[tokio::test]
    async fn test_snapshot_different_policy_ids_are_separate() {
        let cache = PolicyContractCache::new(test_config());
        let chain_id = 1u64;
        let id_a = B256::repeat_byte(0x01);
        let id_b = B256::repeat_byte(0x02);
        cache.insert_snapshot(chain_id, id_a, test_snapshot()).await;
        assert!(cache.get_snapshot(chain_id, id_a).await.is_some());
        assert!(cache.get_snapshot(chain_id, id_b).await.is_none());
    }

    #[tokio::test]
    async fn test_version_cache_miss_then_hit() {
        let cache = PolicyContractCache::new(test_config());
        let chain_id = 1u64;
        let addr = Address::repeat_byte(0x01);
        assert!(cache.get_policy_version(chain_id, addr).await.is_none());
        cache.insert_policy_version(chain_id, addr, "0.3.0".to_string()).await;
        let cached = cache.get_policy_version(chain_id, addr).await.unwrap();
        assert_eq!(cached, "0.3.0");
    }

    #[tokio::test]
    async fn test_version_ttl_expiry() {
        let cache = PolicyContractCache::new(test_config());
        let chain_id = 1u64;
        let addr = Address::repeat_byte(0x01);
        cache.insert_policy_version(chain_id, addr, "0.3.0".to_string()).await;
        assert!(cache.get_policy_version(chain_id, addr).await.is_some());
        tokio::time::sleep(Duration::from_millis(1100)).await;
        assert!(cache.get_policy_version(chain_id, addr).await.is_none());
    }

    #[tokio::test]
    async fn test_policy_data_version_cache() {
        let cache = PolicyContractCache::new(test_config());
        let chain_id = 1u64;
        let addr = Address::repeat_byte(0x01);
        assert!(cache.get_policy_data_version(chain_id, addr).await.is_none());
        cache
            .insert_policy_data_version(chain_id, addr, "0.3.0".to_string())
            .await;
        let cached = cache.get_policy_data_version(chain_id, addr).await.unwrap();
        assert_eq!(cached, "0.3.0");
    }

    #[tokio::test]
    async fn test_version_invalidation() {
        let cache = PolicyContractCache::new(test_config());
        let chain_id = 1u64;
        let addr = Address::repeat_byte(0x04);
        cache.insert_policy_version(chain_id, addr, "0.3.0".to_string()).await;
        cache
            .insert_policy_data_version(chain_id, addr, "0.4.0".to_string())
            .await;
        cache.invalidate_policy_version(chain_id, addr).await;
        assert!(cache.get_policy_version(chain_id, addr).await.is_none());
        // The two version caches are independent: invalidating one leaves the other.
        assert_eq!(
            cache.get_policy_data_version(chain_id, addr).await.as_deref(),
            Some("0.4.0")
        );
        cache.invalidate_policy_data_version(chain_id, addr).await;
        assert!(cache.get_policy_data_version(chain_id, addr).await.is_none());
    }

    #[tokio::test]
    async fn test_privacy_requires_cache_and_invalidation() {
        let cache = PolicyContractCache::new(test_config());
        let chain_id = 1u64;
        let client = Address::repeat_byte(0x03);
        assert!(cache.get_privacy_requires(chain_id, client).await.is_none());
        cache.insert_privacy_requires(chain_id, client, true).await;
        assert_eq!(cache.get_privacy_requires(chain_id, client).await, Some(true));
        assert!(cache.get_privacy_requires(2, client).await.is_none());
        cache.invalidate_privacy_requires(chain_id, client).await;
        assert!(cache.get_privacy_requires(chain_id, client).await.is_none());
    }

    #[tokio::test]
    async fn test_get_or_try_insert_caches_first_ok_value() {
        let cache = PolicyContractCache::new(test_config());
        let addr = Address::repeat_byte(0x05);
        let first = cache
            .get_or_try_insert_policy_version(1, addr, async {
                Ok::<_, std::convert::Infallible>("0.3.0".to_string())
            })
            .await
            .unwrap();
        assert_eq!(first, "0.3.0");
        // On a hit the cached value wins; a second init value is not used.
        let second = cache
            .get_or_try_insert_policy_version(1, addr, async {
                Ok::<_, std::convert::Infallible>("9.9.9".to_string())
            })
            .await
            .unwrap();
        assert_eq!(second, "0.3.0");
    }

    #[tokio::test]
    async fn test_get_or_try_insert_does_not_cache_errors() {
        let cache = PolicyContractCache::new(test_config());
        let client = Address::repeat_byte(0x06);
        let failed = cache
            .get_or_try_insert_privacy_requires(1, client, async { Err::<bool, _>("rpc failed") })
            .await;
        assert_eq!(*failed.unwrap_err(), "rpc failed");
        assert!(cache.get_privacy_requires(1, client).await.is_none());
        let ok = cache
            .get_or_try_insert_privacy_requires(1, client, async { Ok::<_, &str>(true) })
            .await
            .unwrap();
        assert!(ok);
        assert_eq!(cache.get_privacy_requires(1, client).await, Some(true));
    }

    #[tokio::test]
    async fn test_snapshot_invalidation() {
        let cache = PolicyContractCache::new(test_config());
        let chain_id = 1u64;
        let policy_id = B256::repeat_byte(0x01);
        cache.insert_snapshot(chain_id, policy_id, test_snapshot()).await;
        assert!(cache.get_snapshot(chain_id, policy_id).await.is_some());
        cache.invalidate_snapshot(chain_id, policy_id).await;
        assert!(cache.get_snapshot(chain_id, policy_id).await.is_none());
    }

    #[tokio::test]
    async fn test_snapshot_no_ttl_persists() {
        let cache = PolicyContractCache::new(test_config());
        let chain_id = 1u64;
        let policy_id = B256::repeat_byte(0x01);
        cache.insert_snapshot(chain_id, policy_id, test_snapshot()).await;
        tokio::time::sleep(Duration::from_millis(1100)).await;
        assert!(cache.get_snapshot(chain_id, policy_id).await.is_some());
    }

    #[test]
    fn test_default_config() {
        let config = PolicyContractCacheConfig::default();
        assert_eq!(config.snapshot_max_capacity, 1024);
        assert_eq!(config.version_ttl_secs, 600);
        assert_eq!(config.version_max_capacity, 512);
    }
}