newton-core 0.4.16

newton protocol core sdk
use alloy::primitives::{Address, Bytes, B256};
use moka::future::Cache;
use std::time::Duration;

/// Cached snapshot of a policy client's resolved state.
///
/// All fields are pinned to the policyId generation — content-addressed
/// and consensus-safe by construction.
///
/// Implements `PartialEq`/`Eq` so two snapshots can be compared as a whole
/// (every field, including `policy_params`) instead of field by field.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct PolicySnapshot {
    /// Resolved policy contract address for this client.
    pub policy_address: Address,
    /// Content-addressed policy generation hash.
    pub policy_id: B256,
    /// ABI-encoded policy parameters.
    pub policy_params: Bytes,
    /// Block-based expiration window for attestations.
    pub expire_after: u32,
}

/// Configuration for the policy contract cache.
///
/// The snapshot cache has its own capacity setting. The two `version_*`
/// settings are shared by every TTL-based cache: the policy version cache,
/// the policy data version cache, and the privacy classification cache.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct PolicyContractCacheConfig {
    /// Max entries for the policy snapshot cache (capacity-bounded, no TTL).
    pub snapshot_max_capacity: u64,
    /// TTL, in seconds, applied to each TTL-based cache (both version caches
    /// and the privacy classification cache).
    pub version_ttl_secs: u64,
    /// Max entries for each TTL-based cache (both version caches and the
    /// privacy classification cache); the limit applies per cache, not in total.
    pub version_max_capacity: u64,
}

impl Default for PolicyContractCacheConfig {
    /// Defaults: 1024 snapshot entries, and a 600-second TTL with 512 entries
    /// for each TTL-based cache.
    fn default() -> Self {
        Self {
            snapshot_max_capacity: 1024,
            version_ttl_secs: 600,
            version_max_capacity: 512,
        }
    }
}

/// Shared cache for on-chain policy contract reads.
///
/// Four internal caches:
/// - `policy_snapshot`: `(chain_id, policy_id)` -> `PolicySnapshot` (no TTL, capacity-bounded only)
/// - `policy_version`: `(chain_id, policy_address)` -> version string (TTL-based)
/// - `policy_data_version`: `(chain_id, policy_data_address)` -> version string (TTL-based)
/// - `privacy_requires`: `(chain_id, policy_client)` -> privacy flag (TTL-based)
///
/// The snapshot cache is consensus-safe: `policyId` is a keccak256 content hash that
/// includes all config fields plus `block.timestamp`, so it rotates on any config
/// change. Entries are correct forever for their generation — no TTL needed.
///
/// Version caches use TTL (default 600s) because versions change only on proxy
/// upgrade, and a stale version just means an extra RPC call on the next TTL window.
/// The privacy cache is built with the same TTL and capacity settings as the
/// version caches.
///
/// Services share this via `Arc<PolicyContractCache>`.
pub struct PolicyContractCache {
    /// Policy snapshots keyed by `(chain_id, policy_id)`; built without a TTL.
    policy_snapshot: Cache<(u64, B256), PolicySnapshot>,
    /// Policy implementation versions keyed by `(chain_id, policy_address)`.
    policy_version: Cache<(u64, Address), String>,
    /// Policy data implementation versions keyed by `(chain_id, policy_data_address)`.
    policy_data_version: Cache<(u64, Address), String>,
    /// Per-policy-client privacy classification from Rego source scan,
    /// keyed by `(chain_id, policy_client)`.
    /// TTL-based (same as version caches) because a policy upgrade may change
    /// whether the Rego reads privacy namespaces. Content-addressed policyCid
    /// makes the result immutable per deployment, but the policy_client address
    /// can point to a new policy after `setPolicy`.
    privacy_requires: Cache<(u64, Address), bool>,
}

impl std::fmt::Debug for PolicyContractCache {
    /// Renders the cache as a struct of per-cache entry counts; cached keys
    /// and values themselves are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let snapshots = self.policy_snapshot.entry_count();
        let policy_versions = self.policy_version.entry_count();
        let data_versions = self.policy_data_version.entry_count();
        let privacy_flags = self.privacy_requires.entry_count();

        let mut out = f.debug_struct("PolicyContractCache");
        out.field("snapshot_entry_count", &snapshots);
        out.field("policy_version_entry_count", &policy_versions);
        out.field("policy_data_version_entry_count", &data_versions);
        out.field("privacy_requires_entry_count", &privacy_flags);
        out.finish()
    }
}

impl Default for PolicyContractCache {
    fn default() -> Self {
        Self::new(PolicyContractCacheConfig::default())
    }
}

impl PolicyContractCache {
    /// Build all four internal caches from `config`.
    ///
    /// The snapshot cache is bounded only by `snapshot_max_capacity`. The policy
    /// version, policy data version, and privacy caches each get
    /// `version_max_capacity` entries and a `version_ttl_secs` time-to-live.
    pub fn new(config: PolicyContractCacheConfig) -> Self {
        // The three TTL-based caches are configured identically.
        let ttl = Duration::from_secs(config.version_ttl_secs);
        let ttl_capacity = config.version_max_capacity;

        Self {
            policy_snapshot: Cache::builder().max_capacity(config.snapshot_max_capacity).build(),
            policy_version: Cache::builder().max_capacity(ttl_capacity).time_to_live(ttl).build(),
            policy_data_version: Cache::builder().max_capacity(ttl_capacity).time_to_live(ttl).build(),
            privacy_requires: Cache::builder().max_capacity(ttl_capacity).time_to_live(ttl).build(),
        }
    }

    // -- Snapshot cache (consensus-safe, no TTL) --

    /// Return the snapshot stored under `(chain_id, policy_id)`, or `None` on a miss.
    pub async fn get_snapshot(&self, chain_id: u64, policy_id: B256) -> Option<PolicySnapshot> {
        let key = (chain_id, policy_id);
        self.policy_snapshot.get(&key).await
    }

    /// Store `snapshot` under `(chain_id, policy_id)`.
    pub async fn insert_snapshot(&self, chain_id: u64, policy_id: B256, snapshot: PolicySnapshot) {
        let key = (chain_id, policy_id);
        self.policy_snapshot.insert(key, snapshot).await;
    }

    /// Return the snapshot for `(chain_id, policy_id)`, running `init` to populate it on a miss.
    ///
    /// Delegates to the cache's `try_get_with`, so concurrent lookups for the same
    /// key are coalesced. An error from `init` is handed back wrapped in an `Arc`.
    pub async fn get_or_try_insert_snapshot<F, E>(
        &self,
        chain_id: u64,
        policy_id: B256,
        init: F,
    ) -> Result<PolicySnapshot, std::sync::Arc<E>>
    where
        F: std::future::Future<Output = Result<PolicySnapshot, E>>,
        E: Send + Sync + 'static,
    {
        let key = (chain_id, policy_id);
        self.policy_snapshot.try_get_with(key, init).await
    }

    // -- Policy version cache (TTL-based) --

    /// Return the policy implementation version stored under
    /// `(chain_id, policy_address)`, or `None` on a miss.
    pub async fn get_policy_version(&self, chain_id: u64, policy_address: Address) -> Option<String> {
        let key = (chain_id, policy_address);
        self.policy_version.get(&key).await
    }

    /// Store a policy implementation `version` under `(chain_id, policy_address)`.
    pub async fn insert_policy_version(&self, chain_id: u64, policy_address: Address, version: String) {
        let key = (chain_id, policy_address);
        self.policy_version.insert(key, version).await;
    }

    /// Return the policy version for `(chain_id, policy_address)`, running `init` on a miss.
    ///
    /// Delegates to the cache's `try_get_with`, so concurrent lookups for the same
    /// key are coalesced. An error from `init` is handed back wrapped in an `Arc`.
    pub async fn get_or_try_insert_policy_version<F, E>(
        &self,
        chain_id: u64,
        policy_address: Address,
        init: F,
    ) -> Result<String, std::sync::Arc<E>>
    where
        F: std::future::Future<Output = Result<String, E>>,
        E: Send + Sync + 'static,
    {
        let key = (chain_id, policy_address);
        self.policy_version.try_get_with(key, init).await
    }

    // -- Policy data version cache (TTL-based) --

    /// Return the policy data implementation version stored under
    /// `(chain_id, policy_data_address)`, or `None` on a miss.
    pub async fn get_policy_data_version(&self, chain_id: u64, policy_data_address: Address) -> Option<String> {
        let key = (chain_id, policy_data_address);
        self.policy_data_version.get(&key).await
    }

    /// Store a policy data implementation `version` under `(chain_id, policy_data_address)`.
    pub async fn insert_policy_data_version(&self, chain_id: u64, policy_data_address: Address, version: String) {
        let key = (chain_id, policy_data_address);
        self.policy_data_version.insert(key, version).await;
    }

    /// Return the policy data version for `(chain_id, policy_data_address)`,
    /// running `init` on a miss.
    ///
    /// Delegates to the cache's `try_get_with`, so concurrent lookups for the same
    /// key are coalesced. An error from `init` is handed back wrapped in an `Arc`.
    pub async fn get_or_try_insert_policy_data_version<F, E>(
        &self,
        chain_id: u64,
        policy_data_address: Address,
        init: F,
    ) -> Result<String, std::sync::Arc<E>>
    where
        F: std::future::Future<Output = Result<String, E>>,
        E: Send + Sync + 'static,
    {
        let key = (chain_id, policy_data_address);
        self.policy_data_version.try_get_with(key, init).await
    }

    // -- Privacy detection cache (TTL-based) --

    /// Return the privacy classification stored under `(chain_id, policy_client)`,
    /// or `None` on a miss.
    pub async fn get_privacy_requires(&self, chain_id: u64, policy_client: Address) -> Option<bool> {
        let key = (chain_id, policy_client);
        self.privacy_requires.get(&key).await
    }

    /// Store the privacy classification `requires` under `(chain_id, policy_client)`.
    pub async fn insert_privacy_requires(&self, chain_id: u64, policy_client: Address, requires: bool) {
        let key = (chain_id, policy_client);
        self.privacy_requires.insert(key, requires).await;
    }

    /// Return the privacy classification for `(chain_id, policy_client)`,
    /// running `init` on a miss.
    ///
    /// Delegates to the cache's `try_get_with`, so concurrent lookups for the same
    /// key are coalesced. An error from `init` is handed back wrapped in an `Arc`.
    pub async fn get_or_try_insert_privacy_requires<F, E>(
        &self,
        chain_id: u64,
        policy_client: Address,
        init: F,
    ) -> Result<bool, std::sync::Arc<E>>
    where
        F: std::future::Future<Output = Result<bool, E>>,
        E: Send + Sync + 'static,
    {
        let key = (chain_id, policy_client);
        self.privacy_requires.try_get_with(key, init).await
    }

    // -- Invalidation --

    /// Drop the snapshot stored under `(chain_id, policy_id)`, e.g. on policy rotation.
    pub async fn invalidate_snapshot(&self, chain_id: u64, policy_id: B256) {
        let key = (chain_id, policy_id);
        self.policy_snapshot.invalidate(&key).await;
    }

    /// Drop the policy version stored under `(chain_id, policy_address)`,
    /// e.g. on proxy upgrade.
    pub async fn invalidate_policy_version(&self, chain_id: u64, policy_address: Address) {
        let key = (chain_id, policy_address);
        self.policy_version.invalidate(&key).await;
    }

    /// Drop the policy data version stored under `(chain_id, policy_data_address)`,
    /// e.g. on proxy upgrade.
    pub async fn invalidate_policy_data_version(&self, chain_id: u64, policy_data_address: Address) {
        let key = (chain_id, policy_data_address);
        self.policy_data_version.invalidate(&key).await;
    }

    /// Drop the privacy classification stored under `(chain_id, policy_client)`,
    /// e.g. when a policy upgrade changes the Rego source.
    pub async fn invalidate_privacy_requires(&self, chain_id: u64, policy_client: Address) {
        let key = (chain_id, policy_client);
        self.privacy_requires.invalidate(&key).await;
    }
}

#[cfg(test)]
mod tests {
    //! Unit tests for `PolicyContractCache`: hit/miss behavior, key isolation,
    //! TTL expiry, invalidation, and the get-or-populate paths for every cache.

    use super::*;
    use alloy::primitives::{Address, Bytes, B256};
    use std::time::Duration;

    /// Small capacities and a 1-second TTL so expiry can be observed quickly.
    fn test_config() -> PolicyContractCacheConfig {
        PolicyContractCacheConfig {
            snapshot_max_capacity: 10,
            version_ttl_secs: 1,
            version_max_capacity: 10,
        }
    }

    /// A fixed snapshot with recognizable field values.
    fn test_snapshot() -> PolicySnapshot {
        PolicySnapshot {
            policy_address: Address::repeat_byte(0xAA),
            policy_id: B256::repeat_byte(0xBB),
            policy_params: Bytes::from_static(b"test_params"),
            expire_after: 100,
        }
    }

    #[tokio::test]
    async fn test_snapshot_cache_miss_then_hit() {
        let cache = PolicyContractCache::new(test_config());
        let chain_id = 1u64;
        let policy_id = B256::repeat_byte(0x01);

        assert!(cache.get_snapshot(chain_id, policy_id).await.is_none());

        let snapshot = test_snapshot();
        cache.insert_snapshot(chain_id, policy_id, snapshot.clone()).await;

        // Every field must round-trip, including the encoded parameters.
        let cached = cache.get_snapshot(chain_id, policy_id).await.unwrap();
        assert_eq!(cached.policy_address, snapshot.policy_address);
        assert_eq!(cached.policy_id, snapshot.policy_id);
        assert_eq!(cached.policy_params, snapshot.policy_params);
        assert_eq!(cached.expire_after, snapshot.expire_after);
    }

    #[tokio::test]
    async fn test_snapshot_different_chain_ids_are_separate() {
        let cache = PolicyContractCache::new(test_config());
        let policy_id = B256::repeat_byte(0x01);

        cache.insert_snapshot(1, policy_id, test_snapshot()).await;

        assert!(cache.get_snapshot(1, policy_id).await.is_some());
        assert!(cache.get_snapshot(2, policy_id).await.is_none());
    }

    #[tokio::test]
    async fn test_snapshot_different_policy_ids_are_separate() {
        let cache = PolicyContractCache::new(test_config());
        let chain_id = 1u64;
        let id_a = B256::repeat_byte(0x01);
        let id_b = B256::repeat_byte(0x02);

        cache.insert_snapshot(chain_id, id_a, test_snapshot()).await;

        assert!(cache.get_snapshot(chain_id, id_a).await.is_some());
        assert!(cache.get_snapshot(chain_id, id_b).await.is_none());
    }

    #[tokio::test]
    async fn test_version_cache_miss_then_hit() {
        let cache = PolicyContractCache::new(test_config());
        let chain_id = 1u64;
        let addr = Address::repeat_byte(0x01);

        assert!(cache.get_policy_version(chain_id, addr).await.is_none());

        cache.insert_policy_version(chain_id, addr, "0.3.0".to_string()).await;

        let cached = cache.get_policy_version(chain_id, addr).await.unwrap();
        assert_eq!(cached, "0.3.0");
    }

    #[tokio::test]
    async fn test_version_ttl_expiry() {
        let cache = PolicyContractCache::new(test_config()); // 1 second TTL
        let chain_id = 1u64;
        let addr = Address::repeat_byte(0x01);

        cache.insert_policy_version(chain_id, addr, "0.3.0".to_string()).await;
        assert!(cache.get_policy_version(chain_id, addr).await.is_some());

        tokio::time::sleep(Duration::from_millis(1100)).await;

        assert!(cache.get_policy_version(chain_id, addr).await.is_none());
    }

    #[tokio::test]
    async fn test_policy_data_version_cache() {
        let cache = PolicyContractCache::new(test_config());
        let chain_id = 1u64;
        let addr = Address::repeat_byte(0x01);

        assert!(cache.get_policy_data_version(chain_id, addr).await.is_none());

        cache
            .insert_policy_data_version(chain_id, addr, "0.3.0".to_string())
            .await;

        let cached = cache.get_policy_data_version(chain_id, addr).await.unwrap();
        assert_eq!(cached, "0.3.0");
    }

    #[tokio::test]
    async fn test_version_caches_are_independent() {
        let cache = PolicyContractCache::new(test_config());
        let chain_id = 1u64;
        let addr = Address::repeat_byte(0x01);

        // The same address in the policy version cache must not leak into
        // the policy data version cache.
        cache.insert_policy_version(chain_id, addr, "0.3.0".to_string()).await;

        assert!(cache.get_policy_version(chain_id, addr).await.is_some());
        assert!(cache.get_policy_data_version(chain_id, addr).await.is_none());
    }

    #[tokio::test]
    async fn test_snapshot_invalidation() {
        let cache = PolicyContractCache::new(test_config());
        let chain_id = 1u64;
        let policy_id = B256::repeat_byte(0x01);

        cache.insert_snapshot(chain_id, policy_id, test_snapshot()).await;
        assert!(cache.get_snapshot(chain_id, policy_id).await.is_some());

        cache.invalidate_snapshot(chain_id, policy_id).await;
        assert!(cache.get_snapshot(chain_id, policy_id).await.is_none());
    }

    #[tokio::test]
    async fn test_policy_version_invalidation() {
        let cache = PolicyContractCache::new(test_config());
        let chain_id = 1u64;
        let addr = Address::repeat_byte(0x01);

        cache.insert_policy_version(chain_id, addr, "0.3.0".to_string()).await;
        assert!(cache.get_policy_version(chain_id, addr).await.is_some());

        cache.invalidate_policy_version(chain_id, addr).await;
        assert!(cache.get_policy_version(chain_id, addr).await.is_none());
    }

    #[tokio::test]
    async fn test_policy_data_version_invalidation() {
        let cache = PolicyContractCache::new(test_config());
        let chain_id = 1u64;
        let addr = Address::repeat_byte(0x01);

        cache
            .insert_policy_data_version(chain_id, addr, "0.3.0".to_string())
            .await;
        assert!(cache.get_policy_data_version(chain_id, addr).await.is_some());

        cache.invalidate_policy_data_version(chain_id, addr).await;
        assert!(cache.get_policy_data_version(chain_id, addr).await.is_none());
    }

    #[tokio::test]
    async fn test_snapshot_no_ttl_persists() {
        // The 1s TTL applies to the version and privacy caches, not to snapshots.
        let cache = PolicyContractCache::new(test_config());
        let chain_id = 1u64;
        let policy_id = B256::repeat_byte(0x01);

        cache.insert_snapshot(chain_id, policy_id, test_snapshot()).await;

        tokio::time::sleep(Duration::from_millis(1100)).await;

        // Snapshot has no TTL — should still be there (only capacity evicts)
        assert!(cache.get_snapshot(chain_id, policy_id).await.is_some());
    }

    #[tokio::test]
    async fn test_privacy_requires_cache_miss_then_hit() {
        let cache = PolicyContractCache::new(test_config());
        let chain_id = 1u64;
        let client = Address::repeat_byte(0x01);

        assert_eq!(cache.get_privacy_requires(chain_id, client).await, None);

        cache.insert_privacy_requires(chain_id, client, true).await;

        assert_eq!(cache.get_privacy_requires(chain_id, client).await, Some(true));
        // A different chain id is a different key.
        assert_eq!(cache.get_privacy_requires(2, client).await, None);
    }

    #[tokio::test]
    async fn test_privacy_requires_invalidation() {
        let cache = PolicyContractCache::new(test_config());
        let chain_id = 1u64;
        let client = Address::repeat_byte(0x01);

        cache.insert_privacy_requires(chain_id, client, false).await;
        assert_eq!(cache.get_privacy_requires(chain_id, client).await, Some(false));

        cache.invalidate_privacy_requires(chain_id, client).await;
        assert_eq!(cache.get_privacy_requires(chain_id, client).await, None);
    }

    #[tokio::test]
    async fn test_get_or_try_insert_snapshot_uses_cached_value_on_hit() {
        let cache = PolicyContractCache::new(test_config());
        let chain_id = 1u64;
        let policy_id = B256::repeat_byte(0x01);

        let first = cache
            .get_or_try_insert_snapshot(chain_id, policy_id, async { Ok::<_, String>(test_snapshot()) })
            .await
            .unwrap();
        assert_eq!(first.expire_after, 100);

        // The key is now populated, so the second init's value must be ignored.
        let second = cache
            .get_or_try_insert_snapshot(chain_id, policy_id, async {
                Ok::<_, String>(PolicySnapshot {
                    expire_after: 999,
                    ..test_snapshot()
                })
            })
            .await
            .unwrap();
        assert_eq!(second.expire_after, 100);
    }

    #[tokio::test]
    async fn test_get_or_try_insert_policy_version_error_is_not_cached() {
        let cache = PolicyContractCache::new(test_config());
        let chain_id = 1u64;
        let addr = Address::repeat_byte(0x01);

        let err = cache
            .get_or_try_insert_policy_version(chain_id, addr, async {
                Err::<String, String>("rpc failed".to_string())
            })
            .await
            .unwrap_err();
        assert_eq!(err.as_str(), "rpc failed");
        assert!(cache.get_policy_version(chain_id, addr).await.is_none());

        // A later successful init populates the entry.
        let version = cache
            .get_or_try_insert_policy_version(chain_id, addr, async { Ok::<_, String>("0.3.0".to_string()) })
            .await
            .unwrap();
        assert_eq!(version, "0.3.0");
        assert_eq!(cache.get_policy_version(chain_id, addr).await.as_deref(), Some("0.3.0"));
    }

    #[tokio::test]
    async fn test_get_or_try_insert_policy_data_version_populates_cache() {
        let cache = PolicyContractCache::new(test_config());
        let chain_id = 1u64;
        let addr = Address::repeat_byte(0x01);

        let version = cache
            .get_or_try_insert_policy_data_version(chain_id, addr, async {
                Ok::<_, String>("0.3.0".to_string())
            })
            .await
            .unwrap();
        assert_eq!(version, "0.3.0");
        assert_eq!(
            cache.get_policy_data_version(chain_id, addr).await.as_deref(),
            Some("0.3.0")
        );
    }

    #[tokio::test]
    async fn test_get_or_try_insert_privacy_requires_populates_cache() {
        let cache = PolicyContractCache::new(test_config());
        let chain_id = 1u64;
        let client = Address::repeat_byte(0x01);

        let requires = cache
            .get_or_try_insert_privacy_requires(chain_id, client, async { Ok::<_, String>(true) })
            .await
            .unwrap();
        assert!(requires);
        assert_eq!(cache.get_privacy_requires(chain_id, client).await, Some(true));

        // A hit returns the stored classification, not the new init's value.
        let again = cache
            .get_or_try_insert_privacy_requires(chain_id, client, async { Ok::<_, String>(false) })
            .await
            .unwrap();
        assert!(again);
    }

    // Plain synchronous test: nothing here awaits, so no async runtime is needed.
    #[test]
    fn test_default_config() {
        let config = PolicyContractCacheConfig::default();
        assert_eq!(config.snapshot_max_capacity, 1024);
        assert_eq!(config.version_ttl_secs, 600);
        assert_eq!(config.version_max_capacity, 512);
    }
}