use crate::{ChainType, Error, Result, TransactionRequest};
use chrono::{DateTime, Datelike, Timelike, Utc};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
/// Outcome of evaluating a transaction against a policy.
///
/// Serialized by serde as an internally tagged value: the variant name is
/// written in snake_case under the `type` key.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum PolicyDecision {
/// The transaction is allowed.
Approve,
/// The transaction is refused; `reason` is a human-readable explanation.
Reject { reason: String },
/// The transaction is not refused outright but needs a further approval
/// step; `reason` explains what triggered the requirement.
RequireAdditionalApproval { reason: String },
}
impl PolicyDecision {
    /// Returns `true` only for [`PolicyDecision::Approve`].
    pub fn is_approved(&self) -> bool {
        match self {
            PolicyDecision::Approve => true,
            PolicyDecision::Reject { .. } | PolicyDecision::RequireAdditionalApproval { .. } => {
                false
            }
        }
    }

    /// Returns `true` only for [`PolicyDecision::RequireAdditionalApproval`].
    pub fn requires_additional_approval(&self) -> bool {
        match self {
            PolicyDecision::RequireAdditionalApproval { .. } => true,
            PolicyDecision::Approve | PolicyDecision::Reject { .. } => false,
        }
    }
}
/// Spending caps applied to transactions on a single chain.
///
/// Every cap is optional; `None` means that particular cap is not enforced.
/// The amounts are compared directly against the integer value the policy
/// engine parses from a transaction (presumably the chain's smallest unit,
/// e.g. wei — confirm against callers).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpendingLimits {
/// Maximum value of a single transaction.
pub per_transaction: Option<u128>,
/// Maximum total recorded value per calendar day.
pub daily: Option<u128>,
/// Maximum total recorded value per week.
pub weekly: Option<u128>,
/// Currency label; in this file it is only used in rejection messages.
pub currency: String,
}
impl Default for SpendingLimits {
    /// No caps at all, with the currency label set to `"ETH"`.
    fn default() -> Self {
        let currency = String::from("ETH");
        Self {
            currency,
            per_transaction: None,
            daily: None,
            weekly: None,
        }
    }
}
impl SpendingLimits {
    /// Creates limits with only a per-transaction cap and the given currency label.
    pub fn with_per_tx(amount: u128, currency: impl Into<String>) -> Self {
        let currency = currency.into();
        Self {
            currency,
            per_transaction: Some(amount),
            daily: None,
            weekly: None,
        }
    }

    /// Returns these limits with the daily cap set to `amount`.
    pub fn daily(self, amount: u128) -> Self {
        Self {
            daily: Some(amount),
            ..self
        }
    }

    /// Returns these limits with the weekly cap set to `amount`.
    pub fn weekly(self, amount: u128) -> Self {
        Self {
            weekly: Some(amount),
            ..self
        }
    }
}
/// Time-of-day and day-of-week window during which transactions are allowed.
///
/// The window is checked against a UTC timestamp. The hour range is
/// half-open: `start_hour` is inside the window, `end_hour` is not. When
/// `start_hour` is greater than `end_hour` the window wraps past midnight.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TimeBounds {
/// First hour (inclusive) of the allowed window.
pub start_hour: u8,
/// Hour at which the allowed window ends (exclusive); `24` means end of day.
pub end_hour: u8,
/// Allowed weekdays counted from Sunday (`0` = Sunday ... `6` = Saturday).
pub allowed_days: Vec<u8>,
}
impl Default for TimeBounds {
    /// An unrestricted window: hours 0 to 24 on all seven days.
    fn default() -> Self {
        Self {
            allowed_days: (0..=6).collect(),
            start_hour: 0,
            end_hour: 24,
        }
    }
}
impl TimeBounds {
    /// Monday through Friday, from 09:00 up to (but not including) 17:00.
    pub fn business_hours() -> Self {
        Self {
            start_hour: 9,
            end_hour: 17,
            allowed_days: (1..=5).collect(),
        }
    }

    /// Returns `true` when `timestamp` falls on an allowed weekday and inside
    /// the allowed hour window.
    ///
    /// The weekday is taken from `timestamp` itself, also for windows that
    /// wrap past midnight.
    pub fn is_allowed(&self, timestamp: DateTime<Utc>) -> bool {
        let weekday = timestamp.weekday().num_days_from_sunday() as u8;
        if !self.allowed_days.contains(&weekday) {
            return false;
        }

        let hour = timestamp.hour() as u8;
        let (start, end) = (self.start_hour, self.end_hour);
        if start <= end {
            // Ordinary window within a single day: [start, end).
            (start..end).contains(&hour)
        } else {
            // Window wraps past midnight, e.g. 22:00 to 06:00.
            hour >= start || hour < end
        }
    }
}
/// Restrictions applied to transactions that are contract calls.
///
/// The policy engine compares lowercased recipient addresses against
/// `allowed_contracts`, and hex-encoded function selectors (no `0x` prefix)
/// against the two selector sets, so entries are expected in that same form.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContractRestriction {
/// Contracts that may be called; an empty set allows every contract.
pub allowed_contracts: HashSet<String>,
/// Function selectors that may be called; an empty set allows every selector
/// that is not blocked.
pub allowed_selectors: HashSet<String>,
/// Function selectors that are always rejected.
pub blocked_selectors: HashSet<String>,
}
impl Default for ContractRestriction {
fn default() -> Self {
Self {
allowed_contracts: HashSet::new(),
allowed_selectors: HashSet::new(),
blocked_selectors: HashSet::new(),
}
}
}
impl ContractRestriction {
    /// Adds `address` (stored lowercased) to the set of callable contracts.
    pub fn allow_contract(mut self, address: impl Into<String>) -> Self {
        self.allowed_contracts.insert(address.into().to_lowercase());
        self
    }

    /// Adds a function selector to the allow-list.
    ///
    /// The selector may be written with or without a leading `0x`, in any case.
    pub fn allow_selector(mut self, selector: impl Into<String>) -> Self {
        self.allowed_selectors
            .insert(Self::normalize_selector(&selector.into()));
        self
    }

    /// Adds a function selector to the block-list.
    ///
    /// The selector may be written with or without a leading `0x`, in any case.
    pub fn block_selector(mut self, selector: impl Into<String>) -> Self {
        self.blocked_selectors
            .insert(Self::normalize_selector(&selector.into()));
        self
    }

    /// Brings a selector into the form the policy engine compares against:
    /// lowercase hex digits without a `0x` prefix.
    ///
    /// Without stripping the prefix, a rule added as `"0xa9059cbb"` would never
    /// match and would be silently ineffective.
    fn normalize_selector(selector: &str) -> String {
        let lowered = selector.trim().to_lowercase();
        match lowered.strip_prefix("0x") {
            Some(bare) => bare.to_string(),
            None => lowered,
        }
    }
}
/// Complete set of rules a policy engine evaluates transactions against.
///
/// Address sets are matched against the lowercased recipient of a
/// transaction, so entries that bypass the builder methods (for example when
/// deserialized) must already be lowercase to match.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PolicyConfig {
/// Spending caps per chain; chains without an entry have no caps.
pub spending_limits: HashMap<ChainType, SpendingLimits>,
/// When `Some`, only recipients in this set are allowed.
pub whitelist: Option<HashSet<String>>,
/// Recipients that are always rejected.
pub blacklist: HashSet<String>,
/// When `Some`, transactions are only allowed inside this time window.
pub time_bounds: Option<TimeBounds>,
/// When `Some`, restrictions applied to contract calls.
pub contract_restrictions: Option<ContractRestriction>,
/// When `Some`, values strictly above this amount require additional approval.
pub additional_approval_threshold: Option<u128>,
/// Defaults to 10. NOTE(review): not read anywhere in this part of the file;
/// presumably enforced by whatever queues approval requests — confirm.
pub max_pending_requests: usize,
/// When `false`, evaluation approves every transaction without running any check.
pub enabled: bool,
}
impl Default for PolicyConfig {
    /// An enabled policy with no rules configured and room for 10 pending requests.
    fn default() -> Self {
        Self {
            enabled: true,
            max_pending_requests: 10,
            spending_limits: HashMap::default(),
            blacklist: HashSet::default(),
            whitelist: None,
            time_bounds: None,
            contract_restrictions: None,
            additional_approval_threshold: None,
        }
    }
}
impl PolicyConfig {
    /// Creates the default configuration (see the `Default` impl).
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a default configuration with policy enforcement switched off.
    pub fn disabled() -> Self {
        let mut config = Self::default();
        config.enabled = false;
        config
    }

    /// Sets (or replaces) the spending limits for `chain`.
    pub fn with_spending_limits(mut self, chain: ChainType, limits: SpendingLimits) -> Self {
        self.spending_limits.insert(chain, limits);
        self
    }

    /// Sets the per-transaction cap and currency label on the EVM chain's limits.
    pub fn with_per_tx_limit(mut self, amount: u128, currency: impl Into<String>) -> Self {
        let evm = self.evm_limits_mut();
        evm.per_transaction = Some(amount);
        evm.currency = currency.into();
        self
    }

    /// Sets the daily cap on the EVM chain's limits.
    pub fn with_daily_limit(mut self, amount: u128) -> Self {
        self.evm_limits_mut().daily = Some(amount);
        self
    }

    /// Sets the weekly cap on the EVM chain's limits.
    pub fn with_weekly_limit(mut self, amount: u128) -> Self {
        self.evm_limits_mut().weekly = Some(amount);
        self
    }

    /// Replaces the whitelist with `addresses`, stored lowercased.
    pub fn with_whitelist(mut self, addresses: Vec<String>) -> Self {
        self.whitelist = Some(Self::lowercased(addresses));
        self
    }

    /// Replaces the blacklist with `addresses`, stored lowercased.
    pub fn with_blacklist(mut self, addresses: Vec<String>) -> Self {
        self.blacklist = Self::lowercased(addresses);
        self
    }

    /// Restricts transactions to the given time window.
    pub fn with_time_bounds(mut self, bounds: TimeBounds) -> Self {
        self.time_bounds = Some(bounds);
        self
    }

    /// Applies the given restrictions to contract calls.
    pub fn with_contract_restrictions(mut self, restrictions: ContractRestriction) -> Self {
        self.contract_restrictions = Some(restrictions);
        self
    }

    /// Requires additional approval for values strictly above `amount`.
    pub fn with_additional_approval_threshold(mut self, amount: u128) -> Self {
        self.additional_approval_threshold = Some(amount);
        self
    }

    /// Returns the EVM chain's limits, inserting default limits if none exist yet.
    fn evm_limits_mut(&mut self) -> &mut SpendingLimits {
        self.spending_limits.entry(ChainType::Evm).or_default()
    }

    /// Collects `addresses` into a set of lowercased strings.
    fn lowercased(addresses: Vec<String>) -> HashSet<String> {
        let mut set = HashSet::new();
        for address in addresses {
            set.insert(address.to_lowercase());
        }
        set
    }
}
/// Running spend totals, bucketed by caller-supplied day and week keys.
#[derive(Debug, Default)]
struct SpendingTracker {
    /// Total recorded per day key.
    daily: HashMap<String, u128>,
    /// Total recorded per week key.
    weekly: HashMap<String, u128>,
}

impl SpendingTracker {
    /// Creates a tracker with nothing recorded.
    fn new() -> Self {
        Self::default()
    }

    /// Returns the total recorded under the day key `date`, or 0 if none.
    fn get_daily_spent(&self, date: &str) -> u128 {
        self.daily.get(date).copied().unwrap_or(0)
    }

    /// Returns the total recorded under the week key `week`, or 0 if none.
    fn get_weekly_spent(&self, week: &str) -> u128 {
        self.weekly.get(week).copied().unwrap_or(0)
    }

    /// Adds `amount` to both the `date` and the `week` bucket.
    ///
    /// Totals saturate at `u128::MAX` instead of overflowing: a wrapped total
    /// would look like almost nothing had been spent and re-open the limits,
    /// and a panic would take down the caller.
    fn record_spending(&mut self, date: &str, week: &str, amount: u128) {
        Self::add_saturating(&mut self.daily, date, amount);
        Self::add_saturating(&mut self.weekly, week, amount);
    }

    /// Drops every bucket except the ones for `current_date` and `current_week`.
    fn cleanup_old_entries(&mut self, current_date: &str, current_week: &str) {
        self.daily.retain(|key, _| key == current_date);
        self.weekly.retain(|key, _| key == current_week);
    }

    /// Adds `amount` to `bucket[key]`, creating the entry at 0 when missing.
    fn add_saturating(bucket: &mut HashMap<String, u128>, key: &str, amount: u128) {
        let total = bucket.entry(key.to_string()).or_insert(0);
        *total = total.saturating_add(amount);
    }
}
/// Evaluates transactions against a `PolicyConfig` and keeps per-chain
/// spending totals for the daily and weekly limits.
#[derive(Debug)]
pub struct PolicyEngine {
// Rules currently in force.
config: PolicyConfig,
// Recorded spending per chain, behind a read/write lock so it can be
// updated through `&self`.
spending: Arc<RwLock<HashMap<ChainType, SpendingTracker>>>,
}
impl PolicyEngine {
pub fn new(config: PolicyConfig) -> Self {
Self {
config,
spending: Arc::new(RwLock::new(HashMap::new())),
}
}
pub fn config(&self) -> &PolicyConfig {
&self.config
}
pub fn update_config(&mut self, config: PolicyConfig) {
self.config = config;
}
pub fn evaluate(&self, tx: &TransactionRequest) -> Result<PolicyDecision> {
if !self.config.enabled {
return Ok(PolicyDecision::Approve);
}
if self.config.blacklist.contains(&tx.to.to_lowercase()) {
return Ok(PolicyDecision::Reject {
reason: format!("Address {} is blacklisted", tx.to),
});
}
if let Some(ref whitelist) = self.config.whitelist {
if !whitelist.contains(&tx.to.to_lowercase()) {
return Ok(PolicyDecision::Reject {
reason: format!("Address {} is not whitelisted", tx.to),
});
}
}
if let Some(ref bounds) = self.config.time_bounds {
let now = Utc::now();
if !bounds.is_allowed(now) {
return Ok(PolicyDecision::Reject {
reason: format!(
"Transaction outside allowed time window ({}:00-{}:00 UTC)",
bounds.start_hour, bounds.end_hour
),
});
}
}
if tx.is_contract_call() {
if let Some(ref restrictions) = self.config.contract_restrictions {
if !restrictions.allowed_contracts.is_empty()
&& !restrictions
.allowed_contracts
.contains(&tx.to.to_lowercase())
{
return Ok(PolicyDecision::Reject {
reason: format!("Contract {} is not in allowed list", tx.to),
});
}
if let Some(selector) = tx.function_selector() {
let selector_hex = hex::encode(selector);
if restrictions.blocked_selectors.contains(&selector_hex) {
return Ok(PolicyDecision::Reject {
reason: format!("Function selector 0x{} is blocked", selector_hex),
});
}
if !restrictions.allowed_selectors.is_empty()
&& !restrictions.allowed_selectors.contains(&selector_hex)
{
return Ok(PolicyDecision::Reject {
reason: format!(
"Function selector 0x{} is not in allowed list",
selector_hex
),
});
}
}
}
}
let value = self.parse_value(&tx.value)?;
if let Some(limits) = self.config.spending_limits.get(&tx.chain) {
if let Some(per_tx) = limits.per_transaction {
if value > per_tx {
return Ok(PolicyDecision::Reject {
reason: format!(
"Transaction value {} exceeds per-transaction limit {}",
tx.value, per_tx
),
});
}
}
let now = Utc::now();
let date_key = now.format("%Y-%m-%d").to_string();
let week_key = now.format("%Y-W%W").to_string();
let spending = self.spending.read();
if let Some(tracker) = spending.get(&tx.chain) {
if let Some(daily_limit) = limits.daily {
let spent = tracker.get_daily_spent(&date_key);
if spent + value > daily_limit {
return Ok(PolicyDecision::Reject {
reason: format!(
"Transaction would exceed daily limit of {} {} (already spent: {})",
daily_limit, limits.currency, spent
),
});
}
}
if let Some(weekly_limit) = limits.weekly {
let spent = tracker.get_weekly_spent(&week_key);
if spent + value > weekly_limit {
return Ok(PolicyDecision::Reject {
reason: format!(
"Transaction would exceed weekly limit of {} {} (already spent: {})",
weekly_limit, limits.currency, spent
),
});
}
}
}
}
if let Some(threshold) = self.config.additional_approval_threshold {
if value > threshold {
return Ok(PolicyDecision::RequireAdditionalApproval {
reason: format!(
"Transaction value {} exceeds additional approval threshold {}",
tx.value, threshold
),
});
}
}
Ok(PolicyDecision::Approve)
}
pub fn record_transaction(&self, tx: &TransactionRequest) -> Result<()> {
let value = self.parse_value(&tx.value)?;
let now = Utc::now();
let date_key = now.format("%Y-%m-%d").to_string();
let week_key = now.format("%Y-W%W").to_string();
let mut spending = self.spending.write();
let tracker = spending
.entry(tx.chain)
.or_insert_with(SpendingTracker::new);
tracker.cleanup_old_entries(&date_key, &week_key);
tracker.record_spending(&date_key, &week_key, value);
Ok(())
}
fn parse_value(&self, value: &str) -> Result<u128> {
if value.contains('.') {
let parts: Vec<&str> = value.split('.').collect();
if parts.len() != 2 {
return Err(Error::PolicyViolation(format!(
"Invalid value format: {}",
value
)));
}
let whole: u128 = parts[0]
.parse()
.map_err(|_| Error::PolicyViolation(format!("Invalid value: {}", value)))?;
let mut decimal_str = parts[1].to_string();
while decimal_str.len() < 18 {
decimal_str.push('0');
}
decimal_str.truncate(18);
let decimal: u128 = decimal_str
.parse()
.map_err(|_| Error::PolicyViolation(format!("Invalid value: {}", value)))?;
Ok(whole * 10u128.pow(18) + decimal)
} else {
value
.parse()
.map_err(|_| Error::PolicyViolation(format!("Invalid value: {}", value)))
}
}
pub fn daily_spending(&self, chain: ChainType) -> u128 {
let date_key = Utc::now().format("%Y-%m-%d").to_string();
let spending = self.spending.read();
spending
.get(&chain)
.map(|t| t.get_daily_spent(&date_key))
.unwrap_or(0)
}
pub fn weekly_spending(&self, chain: ChainType) -> u128 {
let week_key = Utc::now().format("%Y-W%W").to_string();
let spending = self.spending.read();
spending
.get(&chain)
.map(|t| t.get_weekly_spent(&week_key))
.unwrap_or(0)
}
pub fn reset_spending(&self) {
let mut spending = self.spending.write();
spending.clear();
}
}
/// Builder that assembles a `PolicyConfig` step by step, starting from the
/// default configuration.
#[derive(Default)]
pub struct PolicyBuilder {
// Configuration accumulated so far; handed out by `build`.
config: PolicyConfig,
}
impl PolicyBuilder {
    /// Starts from the default policy configuration.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets (or replaces) the spending limits for `chain`.
    pub fn spending_limits(mut self, chain: ChainType, limits: SpendingLimits) -> Self {
        self.config.spending_limits.insert(chain, limits);
        self
    }

    /// Replaces the whitelist with `addresses`, stored lowercased.
    pub fn whitelist(mut self, addresses: impl IntoIterator<Item = impl Into<String>>) -> Self {
        self.config.whitelist = Some(Self::lowercased(addresses));
        self
    }

    /// Replaces the blacklist with `addresses`, stored lowercased.
    pub fn blacklist(mut self, addresses: impl IntoIterator<Item = impl Into<String>>) -> Self {
        self.config.blacklist = Self::lowercased(addresses);
        self
    }

    /// Restricts transactions to the given time window.
    pub fn time_bounds(mut self, bounds: TimeBounds) -> Self {
        self.config.time_bounds = Some(bounds);
        self
    }

    /// Applies the given restrictions to contract calls.
    pub fn contract_restrictions(mut self, restrictions: ContractRestriction) -> Self {
        self.config.contract_restrictions = Some(restrictions);
        self
    }

    /// Requires additional approval for values strictly above `amount`.
    pub fn additional_approval_threshold(mut self, amount: u128) -> Self {
        self.config.additional_approval_threshold = Some(amount);
        self
    }

    /// Finishes building and returns the configuration.
    pub fn build(self) -> PolicyConfig {
        self.config
    }

    /// Collects `addresses` into a set of lowercased strings.
    fn lowercased<I, S>(addresses: I) -> HashSet<String>
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        let mut set = HashSet::new();
        for address in addresses {
            set.insert(address.into().to_lowercase());
        }
        set
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// One ETH expressed in wei.
    const ONE_ETH: u128 = 1_000_000_000_000_000_000;

    /// Unwraps the reason of a rejection and fails the test for any other
    /// decision, so reason assertions can never be skipped silently.
    fn rejection_reason(decision: PolicyDecision) -> String {
        match decision {
            PolicyDecision::Reject { reason } => reason,
            other => panic!("expected a rejection, got {:?}", other),
        }
    }

    /// Parses an RFC 3339 timestamp into a UTC `DateTime`.
    fn utc(timestamp: &str) -> DateTime<Utc> {
        timestamp.parse().expect("valid RFC 3339 timestamp")
    }

    #[test]
    fn test_policy_approve_basic() {
        let engine = PolicyEngine::new(PolicyConfig::default());
        let tx = TransactionRequest::new(ChainType::Evm, "0x1234", "1000000000000000000");
        let decision = engine.evaluate(&tx).unwrap();
        assert!(decision.is_approved());
    }

    #[test]
    fn test_policy_disabled() {
        let engine = PolicyEngine::new(PolicyConfig::disabled());
        let tx = TransactionRequest::new(ChainType::Evm, "0x1234", "999999999999999999999999");
        let decision = engine.evaluate(&tx).unwrap();
        assert!(decision.is_approved());
    }

    #[test]
    fn test_blacklist_rejection() {
        let config = PolicyConfig::default().with_blacklist(vec!["0xBAD".to_string()]);
        let engine = PolicyEngine::new(config);
        let tx = TransactionRequest::new(ChainType::Evm, "0xbad", "1000");
        let reason = rejection_reason(engine.evaluate(&tx).unwrap());
        assert!(reason.contains("blacklisted"));
    }

    #[test]
    fn test_whitelist_rejection() {
        let config = PolicyConfig::default().with_whitelist(vec!["0xGOOD".to_string()]);
        let engine = PolicyEngine::new(config);
        let tx = TransactionRequest::new(ChainType::Evm, "0xOTHER", "1000");
        let reason = rejection_reason(engine.evaluate(&tx).unwrap());
        assert!(reason.contains("not whitelisted"));
    }

    #[test]
    fn test_whitelist_approval() {
        let config = PolicyConfig::default().with_whitelist(vec!["0xGOOD".to_string()]);
        let engine = PolicyEngine::new(config);
        let tx = TransactionRequest::new(ChainType::Evm, "0xgood", "1000");
        let decision = engine.evaluate(&tx).unwrap();
        assert!(decision.is_approved());
    }

    #[test]
    fn test_per_tx_limit() {
        let limits = SpendingLimits::with_per_tx(ONE_ETH, "ETH");
        let config = PolicyConfig::default().with_spending_limits(ChainType::Evm, limits);
        let engine = PolicyEngine::new(config);

        let tx = TransactionRequest::new(ChainType::Evm, "0x1234", "500000000000000000");
        assert!(engine.evaluate(&tx).unwrap().is_approved());

        let tx_over = TransactionRequest::new(ChainType::Evm, "0x1234", "2000000000000000000");
        let reason = rejection_reason(engine.evaluate(&tx_over).unwrap());
        assert!(reason.contains("per-transaction limit"));
    }

    #[test]
    fn test_daily_limit() {
        let limits = SpendingLimits::default().daily(2 * ONE_ETH);
        let config = PolicyConfig::default().with_spending_limits(ChainType::Evm, limits);
        let engine = PolicyEngine::new(config);

        let tx1 = TransactionRequest::new(ChainType::Evm, "0x1234", "1000000000000000000");
        assert!(engine.evaluate(&tx1).unwrap().is_approved());
        engine.record_transaction(&tx1).unwrap();

        let tx2 = TransactionRequest::new(ChainType::Evm, "0x1234", "500000000000000000");
        assert!(engine.evaluate(&tx2).unwrap().is_approved());
        engine.record_transaction(&tx2).unwrap();

        // 1.5 ETH already recorded, so another 1 ETH goes over the 2 ETH cap.
        let tx3 = TransactionRequest::new(ChainType::Evm, "0x1234", "1000000000000000000");
        let reason = rejection_reason(engine.evaluate(&tx3).unwrap());
        assert!(reason.contains("daily limit"));
    }

    #[test]
    fn test_additional_approval_threshold() {
        let config = PolicyConfig::default().with_additional_approval_threshold(5 * ONE_ETH);
        let engine = PolicyEngine::new(config);

        let tx = TransactionRequest::new(ChainType::Evm, "0x1234", "1000000000000000000");
        assert!(engine.evaluate(&tx).unwrap().is_approved());

        let tx_over = TransactionRequest::new(ChainType::Evm, "0x1234", "10000000000000000000");
        let decision = engine.evaluate(&tx_over).unwrap();
        assert!(decision.requires_additional_approval());
    }

    #[test]
    fn test_time_bounds() {
        let bounds = TimeBounds::business_hours();
        assert_eq!(bounds.start_hour, 9);
        assert_eq!(bounds.end_hour, 17);
        assert_eq!(bounds.allowed_days, vec![1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_time_bounds_is_allowed() {
        let business = TimeBounds::business_hours();
        // 2024-01-08 is a Monday, 2024-01-06 a Saturday.
        assert!(business.is_allowed(utc("2024-01-08T10:00:00Z")));
        assert!(business.is_allowed(utc("2024-01-08T09:00:00Z")));
        assert!(!business.is_allowed(utc("2024-01-08T17:00:00Z")));
        assert!(!business.is_allowed(utc("2024-01-08T08:59:59Z")));
        assert!(!business.is_allowed(utc("2024-01-06T10:00:00Z")));

        // A window whose start is after its end wraps past midnight.
        let overnight = TimeBounds {
            start_hour: 22,
            end_hour: 6,
            allowed_days: vec![0, 1, 2, 3, 4, 5, 6],
        };
        assert!(overnight.is_allowed(utc("2024-01-08T23:00:00Z")));
        assert!(overnight.is_allowed(utc("2024-01-08T05:00:00Z")));
        assert!(!overnight.is_allowed(utc("2024-01-08T12:00:00Z")));
    }

    #[test]
    fn test_contract_restrictions() {
        let restrictions = ContractRestriction::default()
            .allow_contract("0xUniswap")
            .block_selector("a9059cbb");
        let config = PolicyConfig::default().with_contract_restrictions(restrictions);
        let engine = PolicyEngine::new(config);

        let mut tx = TransactionRequest::new(ChainType::Evm, "0xuniswap", "0");
        tx.data = Some(vec![0x12, 0x34, 0x56, 0x78]);
        assert!(engine.evaluate(&tx).unwrap().is_approved());

        let mut tx_blocked = TransactionRequest::new(ChainType::Evm, "0xuniswap", "0");
        tx_blocked.data = Some(vec![0xa9, 0x05, 0x9c, 0xbb, 0x00]);
        let reason = rejection_reason(engine.evaluate(&tx_blocked).unwrap());
        assert!(reason.contains("is blocked"));
    }

    #[test]
    fn test_policy_builder() {
        let policy = PolicyBuilder::new()
            .spending_limits(ChainType::Evm, SpendingLimits::with_per_tx(ONE_ETH, "ETH"))
            .whitelist(["0x1234", "0x5678"])
            .blacklist(["0xBAD"])
            .time_bounds(TimeBounds::business_hours())
            .additional_approval_threshold(10 * ONE_ETH)
            .build();

        assert!(policy.whitelist.is_some());
        assert!(policy.blacklist.contains("0xbad"));
        assert!(policy.time_bounds.is_some());
        assert_eq!(policy.additional_approval_threshold, Some(10 * ONE_ETH));
    }

    #[test]
    fn test_parse_decimal_value() {
        let engine = PolicyEngine::new(PolicyConfig::default());

        let value = engine.parse_value("1.5").unwrap();
        assert_eq!(value, 1_500_000_000_000_000_000u128);

        let value = engine.parse_value("0.001").unwrap();
        assert_eq!(value, 1_000_000_000_000_000u128);

        let value = engine.parse_value("1000000000000000000").unwrap();
        assert_eq!(value, 1_000_000_000_000_000_000u128);

        assert!(engine.parse_value("1.2.3").is_err());
        assert!(engine.parse_value("abc").is_err());
    }
}